Merge pull request #28 from Y-IAB/27-instruction
[#27] Add instructions for each category
kangsuhyun-yanolja authored Feb 7, 2024
2 parents b727183 + e4bf3a0 commit 34c6b56
Showing 2 changed files with 137 additions and 109 deletions.
146 changes: 37 additions & 109 deletions app.py
@@ -3,143 +3,66 @@
"""

import enum
from random import sample
from uuid import uuid4

import firebase_admin
from firebase_admin import firestore
import gradio as gr
from litellm import completion

from leaderboard import build_leaderboard
import response
from response import get_responses

# TODO(#21): Fix auto-reload issue related to the initialization of Firebase.
db_app = firebase_admin.initialize_app()
db = firestore.client()

# TODO(#1): Add more models.
SUPPORTED_MODELS = [
"gpt-4", "gpt-4-0125-preview", "gpt-3.5-turbo", "gemini-pro"
]

SUPPORTED_TRANSLATION_LANGUAGES = [
"Korean", "English", "Chinese", "Japanese", "Spanish", "French"
]


class ResponseType(enum.Enum):
SUMMARIZE = "Summarize"
TRANSLATE = "Translate"


class VoteOptions(enum.Enum):
MODEL_A = "Model A is better"
MODEL_B = "Model B is better"
TIE = "Tie"


def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
user_prompt, res_type, source_lang, target_lang):
user_prompt, instruction, category, source_lang, target_lang):
doc_id = uuid4().hex
winner = VoteOptions(vote_button).name.lower()

if res_type == ResponseType.SUMMARIZE.value:
doc = {
"id": doc_id,
"prompt": user_prompt,
"instruction": instruction,
"model_a": model_a_name,
"model_b": model_b_name,
"model_a_response": response_a,
"model_b_response": response_b,
"winner": winner,
"timestamp": firestore.SERVER_TIMESTAMP
}

if category == response.Category.SUMMARIZE.value:
doc_ref = db.collection("arena-summarizations").document(doc_id)
doc_ref.set({
"id": doc_id,
"prompt": user_prompt,
"model_a": model_a_name,
"model_b": model_b_name,
"model_a_response": response_a,
"model_b_response": response_b,
"winner": winner,
"timestamp": firestore.SERVER_TIMESTAMP
})
doc_ref.set(doc)
return

if res_type == ResponseType.TRANSLATE.value:
if category == response.Category.TRANSLATE.value:
doc_ref = db.collection("arena-translations").document(doc_id)
doc_ref.set({
"id": doc_id,
"prompt": user_prompt,
"model_a": model_a_name,
"model_b": model_b_name,
"model_a_response": response_a,
"model_b_response": response_b,
"source_language": source_lang.lower(),
"target_language": target_lang.lower(),
"winner": winner,
"timestamp": firestore.SERVER_TIMESTAMP
})


def response_generator(response: str):
for part in response:
content = part.choices[0].delta.content
if content is None:
continue

# To simulate a stream, we yield each character of the response.
for character in content:
yield character


def get_responses(user_prompt):
models = sample(SUPPORTED_MODELS, 2)

generators = []
for model in models:
try:
# TODO(#1): Allow user to set configuration.
response = completion(model=model,
messages=[{
"content": user_prompt,
"role": "user"
}],
stream=True)
generators.append(response_generator(response))

# TODO(#1): Narrow down the exception type.
except Exception as e: # pylint: disable=broad-except
print(f"Error in bot_response: {e}")
raise e

responses = ["", ""]

# It simulates concurrent response generation from two models.
while True:
stop = True

for i in range(len(generators)):
try:
yielded = next(generators[i])

if yielded is None:
continue

responses[i] += yielded
stop = False

yield responses + models

except StopIteration:
pass

# TODO(#1): Narrow down the exception type.
except Exception as e: # pylint: disable=broad-except
print(f"Error in generator: {e}")
raise e

if stop:
break
doc["source_lang"] = source_lang.lower()
doc["target_lang"] = target_lang.lower()
doc_ref.set(doc)


with gr.Blocks(title="Arena") as app:
with gr.Row():
response_type_radio = gr.Radio(
[response_type.value for response_type in ResponseType],
label="Response type",
info="Choose the type of response you want from the model.")
category_radio = gr.Radio(
[category.value for category in response.Category],
label="Category",
info="The chosen category determines the instruction sent to the LLMs.")

source_language = gr.Dropdown(
choices=SUPPORTED_TRANSLATION_LANGUAGES,
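
Note on the refactor above: vote() now builds one base payload and branches only on the target collection plus the two translation-only fields. A minimal sketch of the stored document shapes, with hypothetical field values:

# Illustration only: hypothetical values for the document vote() writes.
summarize_doc = {
    "id": "9f1c2ab34cd54e6f8a90b12c3d4e5f67",  # uuid4().hex
    "prompt": "Text the user submitted",
    "instruction": "Summarize the following text in its original language.",
    "model_a": "gpt-4",
    "model_b": "gemini-pro",
    "model_a_response": "Summary from model A",
    "model_b_response": "Summary from model B",
    "winner": "model_a",  # VoteOptions(vote_button).name.lower()
    "timestamp": "<set by firestore.SERVER_TIMESTAMP>",
}

# A translation vote stores the same payload, plus two extra fields,
# in the "arena-translations" collection.
translate_doc = {**summarize_doc,
                 "source_lang": "english",
                 "target_lang": "korean"}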
@@ -154,15 +77,15 @@ def get_responses(user_prompt):
         interactive=True,
         visible=False)

-  def update_language_visibility(response_type):
-    visible = response_type == ResponseType.TRANSLATE.value
+  def update_language_visibility(category):
+    visible = category == response.Category.TRANSLATE.value
     return {
         source_language: gr.Dropdown(visible=visible),
         target_language: gr.Dropdown(visible=visible)
     }

-  response_type_radio.change(update_language_visibility, response_type_radio,
-                             [source_language, target_language])
+  category_radio.change(update_language_visibility, category_radio,
+                        [source_language, target_language])

   model_names = [gr.State(None), gr.State(None)]
   response_boxes = [gr.State(None), gr.State(None)]
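
The handler above uses a standard Gradio pattern: the .change callback returns replacement components carrying the desired visibility. A self-contained sketch of the same show/hide pattern (assuming a Gradio version where returning a component updates it, as this commit does):

import gradio as gr

with gr.Blocks() as demo:
  category = gr.Radio(["Summarize", "Translate"], label="Category")
  source = gr.Dropdown(["English", "Korean"], label="From", visible=False)

  def toggle(choice):
    # Show the dropdown only when "Translate" is selected.
    return gr.Dropdown(visible=(choice == "Translate"))

  category.change(toggle, category, source)

demo.launch()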
@@ -175,7 +98,7 @@ def update_language_visibility(response_type):
     response_boxes[1] = gr.Textbox(label="Model B", interactive=False)

   # TODO(#5): Display it only after the user submits the prompt.
-  # TODO(#6): Block voting if the response_type is not set.
+  # TODO(#6): Block voting if the category is not set.
   # TODO(#6): Block voting if the user already voted.
   with gr.Row():
     option_a = gr.Button(VoteOptions.MODEL_A.value)
@@ -188,10 +111,15 @@
     model_names[0] = gr.Textbox(label="Model A", interactive=False)
     model_names[1] = gr.Textbox(label="Model B", interactive=False)

-  submit.click(get_responses, prompt, response_boxes + model_names)
+  instruction_state = gr.State("")
+
+  submit.click(get_responses,
+               [prompt, category_radio, source_language, target_language],
+               response_boxes + model_names + [instruction_state])

   common_inputs = response_boxes + model_names + [
-      prompt, response_type_radio, source_language, target_language
+      prompt, instruction_state, category_radio, source_language,
+      target_language
   ]
   option_a.click(vote, [option_a] + common_inputs)
   option_b.click(vote, [option_b] + common_inputs)
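
Since Gradio passes the inputs list to the callback positionally, the order of [button] + common_inputs must match vote()'s parameter list exactly. An illustrative sanity check, not part of this commit:

from inspect import signature

# vote is the function defined in app.py above.
expected = ["vote_button", "response_a", "response_b", "model_a_name",
            "model_b_name", "user_prompt", "instruction", "category",
            "source_lang", "target_lang"]
assert list(signature(vote).parameters) == expected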
100 changes: 100 additions & 0 deletions response.py
@@ -0,0 +1,100 @@
"""
This module contains functions for generating responses using LLMs.
"""

import enum
from random import sample

import gradio as gr
from litellm import completion

# TODO(#1): Add more models.
SUPPORTED_MODELS = [
"gpt-4", "gpt-4-0125-preview", "gpt-3.5-turbo", "gemini-pro"
]


class Category(enum.Enum):
SUMMARIZE = "Summarize"
TRANSLATE = "Translate"


# TODO(#31): Let the model builders set the instruction.
def get_instruction(category, source_lang, target_lang):
if category == Category.SUMMARIZE.value:
return "Summarize the following text in its original language."
if category == Category.TRANSLATE.value:
return f"Translate the following text from {source_lang} to {target_lang}."


def response_generator(response: str):
for part in response:
content = part.choices[0].delta.content
if content is None:
continue

# To simulate a stream, we yield each character of the response.
for character in content:
yield character


# TODO(#29): Return results simultaneously to prevent bias from generation speed.
def get_responses(user_prompt, category, source_lang, target_lang):
if not category:
raise gr.Error("Please select a category.")

if category == Category.TRANSLATE.value and (not source_lang or
not target_lang):
raise gr.Error("Please select source and target languages.")

models = sample(SUPPORTED_MODELS, 2)
instruction = get_instruction(category, source_lang, target_lang)

generators = []
for model in models:
try:
# TODO(#1): Allow user to set configuration.
response = completion(model=model,
messages=[{
"content": instruction,
"role": "system"
}, {
"content": user_prompt,
"role": "user"
}],
stream=True)
generators.append(response_generator(response))

# TODO(#1): Narrow down the exception type.
except Exception as e: # pylint: disable=broad-except
print(f"Error in bot_response: {e}")
raise e

responses = ["", ""]

# It simulates concurrent response generation from two models.
while True:
stop = True

for i in range(len(generators)):
try:
yielded = next(generators[i])

if yielded is None:
continue

responses[i] += yielded
stop = False

yield responses + models + [instruction]

except StopIteration:
pass

# TODO(#1): Narrow down the exception type.
except Exception as e: # pylint: disable=broad-except
print(f"Error in generator: {e}")
raise e

if stop:
break
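
Each yield from get_responses is a five-element list, [response_a, response_b, model_a, model_b, instruction], which is what submit.click maps onto response_boxes + model_names + [instruction_state]. A sketch of consuming the stream outside Gradio (illustration only; assumes API keys for the sampled providers are configured):

from response import Category, get_responses

for update in get_responses("Bonjour tout le monde",
                            Category.TRANSLATE.value,
                            "French", "English"):
  response_a, response_b, model_a, model_b, instruction = update
  print(f"{model_a}: {response_a!r} | {model_b}: {response_b!r}")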
