-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #28 from Y-IAB/27-instruction
[#27] Add instructions for each category
- Loading branch information
Showing
2 changed files
with
137 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
This module contains functions for generating responses using LLMs. | ||
""" | ||
|
||
import enum | ||
from random import sample | ||
|
||
import gradio as gr | ||
from litellm import completion | ||
|
||
# TODO(#1): Add more models. | ||
SUPPORTED_MODELS = [ | ||
"gpt-4", "gpt-4-0125-preview", "gpt-3.5-turbo", "gemini-pro" | ||
] | ||
|
||
|
||
class Category(enum.Enum): | ||
SUMMARIZE = "Summarize" | ||
TRANSLATE = "Translate" | ||
|
||
|
||
# TODO(#31): Let the model builders set the instruction. | ||
def get_instruction(category, source_lang, target_lang): | ||
if category == Category.SUMMARIZE.value: | ||
return "Summarize the following text in its original language." | ||
if category == Category.TRANSLATE.value: | ||
return f"Translate the following text from {source_lang} to {target_lang}." | ||
|
||
|
||
def response_generator(response: str): | ||
for part in response: | ||
content = part.choices[0].delta.content | ||
if content is None: | ||
continue | ||
|
||
# To simulate a stream, we yield each character of the response. | ||
for character in content: | ||
yield character | ||
|
||
|
||
# TODO(#29): Return results simultaneously to prevent bias from generation speed. | ||
def get_responses(user_prompt, category, source_lang, target_lang): | ||
if not category: | ||
raise gr.Error("Please select a category.") | ||
|
||
if category == Category.TRANSLATE.value and (not source_lang or | ||
not target_lang): | ||
raise gr.Error("Please select source and target languages.") | ||
|
||
models = sample(SUPPORTED_MODELS, 2) | ||
instruction = get_instruction(category, source_lang, target_lang) | ||
|
||
generators = [] | ||
for model in models: | ||
try: | ||
# TODO(#1): Allow user to set configuration. | ||
response = completion(model=model, | ||
messages=[{ | ||
"content": instruction, | ||
"role": "system" | ||
}, { | ||
"content": user_prompt, | ||
"role": "user" | ||
}], | ||
stream=True) | ||
generators.append(response_generator(response)) | ||
|
||
# TODO(#1): Narrow down the exception type. | ||
except Exception as e: # pylint: disable=broad-except | ||
print(f"Error in bot_response: {e}") | ||
raise e | ||
|
||
responses = ["", ""] | ||
|
||
# It simulates concurrent response generation from two models. | ||
while True: | ||
stop = True | ||
|
||
for i in range(len(generators)): | ||
try: | ||
yielded = next(generators[i]) | ||
|
||
if yielded is None: | ||
continue | ||
|
||
responses[i] += yielded | ||
stop = False | ||
|
||
yield responses + models + [instruction] | ||
|
||
except StopIteration: | ||
pass | ||
|
||
# TODO(#1): Narrow down the exception type. | ||
except Exception as e: # pylint: disable=broad-except | ||
print(f"Error in generator: {e}") | ||
raise e | ||
|
||
if stop: | ||
break |