Skip to content

Commit

Permalink
remove unused parameter string_param from all model_handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
HuanzhiMao committed Jul 21, 2024
1 parent ad7c337 commit 33b9fc3
Show file tree
Hide file tree
Showing 14 changed files with 18 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def inference(self, prompt, functions, test_category):
return handler.inference(prompt, functions, test_category)
else:
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category, True)
functions = language_specific_pre_processing(functions, test_category)
if type(functions) is not list:
functions = [functions]
claude_tool = convert_to_tool(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def _get_claude_function_calling_response(self, prompt, functions, test_category
def inference(self, prompt, functions, test_category):
prompt = augment_prompt_by_languge(prompt, test_category)
if "FC" in self.model_name:
functions = language_specific_pre_processing(functions, test_category, True)
functions = language_specific_pre_processing(functions, test_category)
result, metadata = self._get_claude_function_calling_response(
prompt, functions, test_category
)
return result, metadata
else:
start = time.time()
functions = language_specific_pre_processing(
functions, test_category, False
functions, test_category
)
response = self.client.messages.create(
model=self.model_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def inference(self, prompt, functions, test_category):
if "FC" not in self.model_name:
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(
functions, test_category, False
functions, test_category
)
message = USER_PROMPT_FOR_CHAT_MODEL.format(
user_prompt=prompt, functions=str(functions)
Expand All @@ -69,7 +69,7 @@ def inference(self, prompt, functions, test_category):
result = response.text
else:
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category, True)
functions = language_specific_pre_processing(functions, test_category)
if type(functions) is not list:
functions = [functions]
message = prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non
)

def inference(self, prompt, functions, test_category):
functions = language_specific_pre_processing(functions, test_category, False)
functions = language_specific_pre_processing(functions, test_category)
if type(functions) is not list:
functions = [functions]
message = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def write(self, result, file_to_open):
f.write(json.dumps(result) + "\n")

def inference(self, prompt, functions, test_category):
functions = language_specific_pre_processing(functions, test_category, True)
functions = language_specific_pre_processing(functions, test_category)
if type(functions) is not list:
functions = [functions]
message = [{"role": "user", "content": prompt}]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _query_gemini(self, user_query, functions):

def inference(self, prompt, functions, test_category):
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category, True)
functions = language_specific_pre_processing(functions, test_category)
gemini_tool = convert_to_tool(
functions, GORILLA_TO_OPENAPI, self.model_style, test_category, True
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def inference(self, test_question, test_category, num_gpus):
for line in test_question:
prompt = augment_prompt_by_languge(line["question"], test_category)
function = language_specific_pre_processing(
line["function"], test_category, False
line["function"], test_category
)
chat_template_ques_jsons.append(
self.apply_chat_template(prompt, function, test_category)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_gorilla_response(self, prompt, functions):

def inference(self, prompt, functions, test_category):
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category, False)
functions = language_specific_pre_processing(functions, test_category)
if type(functions) is not list:
functions = [functions]
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non
def inference(self, prompt,functions,test_category):
if "FC" not in self.model_name:
prompt = augment_prompt_by_languge(prompt,test_category)
functions = language_specific_pre_processing(functions,test_category,False)
functions = language_specific_pre_processing(functions,test_category)
message = [
{
"role": "system",
Expand All @@ -51,7 +51,7 @@ def inference(self, prompt,functions,test_category):
result = response.choices[0].message.content
else:
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category, True)
functions = language_specific_pre_processing(functions, test_category)
if type(functions) is not list:
functions = [functions]
message = [{"role": "user", "content": prompt}]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _format_prompt(prompt, function, test_category):
if language_specific_prompt_augmented_str.strip():
prompt = prompt.replace(language_specific_prompt_augmented_str, "")

functions = language_specific_pre_processing(function, test_category, False)
functions = language_specific_pre_processing(function, test_category)
functions = convert_to_tool(
functions,
GORILLA_TO_OPENAPI,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non
def inference(self, prompt, functions, test_category):
prompt = augment_prompt_by_languge(prompt, test_category)
if "FC" in self.model_name:
functions = language_specific_pre_processing(functions, test_category, True)
functions = language_specific_pre_processing(functions, test_category)
tool = convert_to_tool(
functions, GORILLA_TO_OPENAPI, self.model_style, test_category, True
)
Expand Down Expand Up @@ -57,7 +57,7 @@ def inference(self, prompt, functions, test_category):
result = chat_response.choices[0].message.content
else:
functions = language_specific_pre_processing(
functions, test_category, False
functions, test_category
)
message = [
ChatMessage(role="system", content=SYSTEM_PROMPT_FOR_CHAT_MODEL),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def query(payload):

def inference(self, prompt, functions, test_category):
prompt = augment_prompt_by_languge(prompt, test_category)
functions = language_specific_pre_processing(functions, test_category, False)
functions = language_specific_pre_processing(functions, test_category)
raven_prompt = self._format_raven_function(prompt, functions)
result, metadata = self._query_raven(raven_prompt)
return result, metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, model_name, temperature=0.7, top_p=1, max_tokens=1000) -> Non
)
def inference(self, prompt, functions, test_category):
prompt = augment_prompt_by_languge(prompt,test_category)
functions = language_specific_pre_processing(functions,test_category,False)
functions = language_specific_pre_processing(functions,test_category)
message = [
{
"role": "system",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _batch_generate(
ques_json = line
prompt = augment_prompt_by_languge(ques_json["question"], test_category)
functions = language_specific_pre_processing(
ques_json["function"], test_category, False
ques_json["function"], test_category
)
prompts.append(format_prompt_func(prompt, functions, test_category))
ans_id = shortuuid.uuid()
Expand Down

0 comments on commit 33b9fc3

Please sign in to comment.