import openai
import streamlit as st
import json
from yaml.loader import SafeLoader
import streamlit_authenticator as stauth
from streamlit_feedback import streamlit_feedback
import yaml
import numpy as np
import pandas as pd
from pathlib import Path
from streamlit_extras.app_logo import add_logo
from trubrics.integrations.streamlit import FeedbackCollector
from search import search_bing
import uuid
import datetime as dt
import humanize
import traceback
from langchain.embeddings.openai import OpenAIEmbeddings
UTC_TIMESTAMP = int(dt.datetime.utcnow().timestamp())
from pymongo import MongoClient
from bson import ObjectId
import re
TIMEOUT = 60
CACHE_SIMILARITY_THRESHOLD = 0.92 # Found by experimentation
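# Cosine similarity between two embedding vectors, factored out of the cache-lookup
# lambda further below (the original inlined this computation). Assumes both inputs
# are 1-D sequences of floats.
def cosine_similarity(a, b):
    # Convert to numpy arrays so this works for plain Python lists as well
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))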
RETRIEVAL_PROMPT = """
You are a powerful AI chat assistant that can answer user questions by retrieving relevant information from various sources. If you call any functions, strictly follow the function descriptions and infer the parameters only from the predefined ones, based on the message history up to this point. Do not make up function call parameters that are not defined.
"""
GENERATION_PROMPT = """
You are a powerful AI chat assistant that can answer user questions by retrieving relevant information from various sources. Be careful to answer the question using only the information from function calls. If they do not return any answers or the answers don't match the question, just say you cannot answer the question and stop there.
If you used one or several retrieved information sources in your answer, please cite the relevant sources at the end of your response, starting with the text "SOURCES: " (always in plural), followed by a JSON standard formatted structure, as in below sample:
SOURCES: {
"sources": [
"source/url 1",
"source/url 2 etc"
]
}
"""
# Embedding model
embeddings = OpenAIEmbeddings(openai_api_key=st.secrets["OPENAI_API_KEY"])
st.set_page_config(
layout="wide",
# initial_sidebar_state="expanded",
page_title="PR🌍MPTERRA by ECOGNIZE",
page_icon="🌍",
menu_items={
"Get Help": "https://www.github.com/nathanyaqueby/ecognize/",
"Report a bug": "https://www.github.com/nathanyaqueby/ecognize/issues",
"About": "Learn to adjust your calls to help the planet!",
},
)
with open("style.css", "r") as css:
st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
add_logo("ecognize logo.png", height=100)
############ functions
def _submit_feedback(user_response, emoji=None):
    st.toast(f"Feedback submitted: {user_response}", icon=emoji)
    # dict.update() returns None, so mutate first and return the dict itself
    user_response.update({"some metadata": 123})
    return user_response
def load_user_points(username):
if user_points := st.session_state.get('user_points', {}).get(username, None):
return user_points
# load the csv file with the user points
user_points_pd = pd.read_csv("user_points.csv")
# get the points for the user
points = user_points_pd[user_points_pd["username"] == username]["user_points"].values[0]
return points
# Replace the following with your own connection string
MONGO_URI = st.secrets["MONGO_URI"]
# Connect to the MongoDB Atlas cluster
client = MongoClient(MONGO_URI)
# Select your database
db = client["junction"]
# Select your collection
collection = db["sustainability_scores"]
embeddings_collection = db["embeddings"]
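# Expected document shape in `sustainability_scores` (inferred from the fields read
# and written below; actual documents may carry more fields):
#   {"_id": ObjectId(...), "username": "outokumpu", "user_point": 10, "sustainability_score": 3}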
def update_user(username, user_point, sustainability_score):
    """ Update the user's points and sustainability score based on username """
    user_document = collection.find_one({"username": username})
    if user_document:
        # Check for existence first so a missing user does not raise a TypeError
        new_user_point = user_document["user_point"] + user_point
        new_sustainability_score = user_document["sustainability_score"] + sustainability_score
        result = collection.update_one(
            {"_id": user_document["_id"]},
            {"$set": {"user_point": new_user_point, "sustainability_score": new_sustainability_score}}
        )
        if result.matched_count > 0:
            print(f"User {username} updated. Points: {new_user_point}, Sustainability Score: {new_sustainability_score}")
        else:
            print(f"Update operation did not find the user {username}")
    else:
        print(f"No user found with the username {username}")
def initialize_cache():
# LOCAL VERSION: load "cache.csv" file into pandas DataFrame if it exists, otherwise create a new one
if Path("cache.csv").is_file():
return pd.read_csv("cache.csv")
else:
return pd.DataFrame(columns=["query", "embedding", "answer", "expires_at"])
def add_to_cache(query, embedding, answer, expires_at):
    # Append one row to the in-session cache (same schema as initialize_cache:
    # query, embedding as a JSON string, answer, expires_at as a unix timestamp)
    st.session_state.cache = pd.concat([
        st.session_state.cache,
        pd.DataFrame([{"query": query, "embedding": embedding, "answer": answer, "expires_at": expires_at}])
    ])
# Callback function for refresh button
def refresh_metrics():
average_points, average_query, user_points, user_num_query = load_all_from_mongo(username)
st.session_state['metrics'] = (average_points, average_query, user_points, user_num_query)
def load_from_mongo(username):
""" Fetch a single document based on username """
query = {"username": username}
document = collection.find_one(query)
return document
def load_all_from_mongo(username):
    # Load the current user's own points and sustainability score
    user_pd = load_from_mongo(username)
    user_points = user_pd["user_point"]
    user_num_query = user_pd["sustainability_score"]
    # Load all known users (nqueby, tianyi, cmakafui, angelineov, outokumpu) from Mongo
    users = ["nqueby", "tianyi", "cmakafui", "angelineov", "outokumpu"]
    sustainability_scores = []
    user_total_points = 0
    for user in users:
        user_document = load_from_mongo(user)
        sustainability_scores.append(user_document["sustainability_score"])
        user_total_points += user_document["user_point"]
    # Average points and average sustainability score across all users
    average_points = user_total_points / len(users)
    average_query = np.mean(sustainability_scores)
    return average_points, average_query, user_points, user_num_query
def add_metrics(cola, colb, username):
    average_points, average_query, user_points, user_num_query = st.session_state['metrics']
    points_help = "Accumulate sustainability points by giving feedback to the LLM's responses or ask a question that is already saved in the cache."
    with cola:
        # st.metric showing the user's points and their percentage above/below the average
        if user_points != average_points:
            st.metric("Your points", f"{user_points} 🌍", f"{round((user_points - average_points) / average_points * 100)} %", help=points_help)
        else:
            st.metric("Your points", f"{user_points} 🌍", "Average", delta_color="off", help=points_help)
    with colb:
        # st.metric showing the user's number of eco-friendly queries vs. the average
        if user_num_query != average_query:
            st.metric("Eco-friendly queries", f"{user_num_query} 🌿", f"{round((user_num_query - average_query) / average_query * 100)} %", help=points_help)
        else:
            st.metric("Eco-friendly queries", f"{user_num_query} 🌿", "Average", delta_color="off", help=points_help)
# List of ambiguous words (this is just an example, you should expand this list)
ambiguous_words = set(["thing", "stuff", "various", "many", "often", "frequently"])
# Function to evaluate if a prompt contains ambiguous words
def has_ambiguous_words(prompt):
words = set(re.findall(r'\b\w+\b', prompt.lower()))
return any(word in ambiguous_words for word in words)
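# Illustrative usage (hypothetical prompts):
#   has_ambiguous_words("Tell me many things")        -> True  ("many" is in the list)
#   has_ambiguous_words("Summarize Q3 steel prices")  -> False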
# Function to evaluate the prompt
def evaluate_prompt(prompt):
criteria = {
"uses_renewable_energy": True, # Example: determine based on user's location or choice
# "uses_smallest_model": st.session_state["openai_model"] == "gpt-3.5-turbo",
"good_prompt_length": len(prompt) < 500,
"no_ambiguous_words": not has_ambiguous_words(prompt),
"no_need_for_clarification": True # TODO: implement logic to determine this
}
# Logic to update other criteria goes here
return criteria
# Update the checklist based on the prompt
def update_checklist(prompt):
st.session_state['checklist'] = evaluate_prompt(prompt)
# Display the checklist in the sidebar
def display_checklist():
    st.sidebar.title("Eco prompt checklist", help="This checklist is used to evaluate the sustainability of your prompt after you input it.")
    for criterion, is_met in st.session_state['checklist'].items():
        icon = "✅" if is_met else "⬜"
        st.sidebar.write(f"{icon} {criterion.replace('_', ' ').capitalize()}")
############
############ Function calling
functions = [
{
"name": "search_bing",
"description": "Search for relevant hits from Bing.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": f"Query string to search the posts for, inferred from the user message and ALWAYS translated to ENGLISH before passing on to the function. Sample query styles: 'expected development of stainless steel market pricing in December 2022' or 'possible energy price developments over January-March 2023'. Note: today is {dt.datetime.utcfromtimestamp(UTC_TIMESTAMP).strftime('%Y-%m-%d')}. You can use this information in your query to absolute dates instead of relative ones."
},
"search_index": {
"type": "string",
"description": "Name of the Bing Search index to use. Valid choices: 'DEFAULT'. If 'DEFAULT' is used, the search will be performed on the entire web."
}
},
"required": ["query", "search_index"]
}
}
]
available_functions = {
"search_bing": search_bing,
}
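# `search_bing` comes from the local `search` module (not shown). Inferred from how
# it is invoked below (fun_res = function(fun_args), with each element later passed
# as the content of a "function" message), its interface is roughly:
#
#   def search_bing(args: dict) -> list[str]:
#       """args is the parsed function-call payload, e.g.
#       {"query": "...", "search_index": "DEFAULT"}; returns result strings."""
#
# This is a sketch of the expected shape, not the module's actual implementation.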
# st.title("PR🌍MPTERRA")
# put logo in the center
col1, col2, col3 = st.columns([1, 6, 1])
with col1:
st.write("")
with col2:
st.image("ecognize logo.png", use_column_width=True)
with col3:
st.write("")
collector = FeedbackCollector(
project="ecognize",
email=st.secrets.TRUBRICS_EMAIL,
password=st.secrets.TRUBRICS_PASSWORD,
)
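# config.yaml is expected to follow the streamlit-authenticator layout (inferred
# from the keys read below), roughly:
#
#   credentials:
#     usernames:
#       someuser:
#         name: Some User
#         password: <hashed password>
#   cookie:
#     name: <cookie name>
#     key: <signature key>
#     expiry_days: 30
#   preauthorized:
#     emails:
#       - someone@example.com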
with open("config.yaml") as file:
config = yaml.load(file, Loader=SafeLoader)
with st.spinner(text="In progress..."):
authenticator = stauth.Authenticate(
config["credentials"],
config["cookie"]["name"],
config["cookie"]["key"],
config["cookie"]["expiry_days"],
config["preauthorized"],
)
name, authentication_status, username = authenticator.login("Login", "main")
# write a welcome message after the user logs in
if authentication_status:
# user_points = load_user_points(username)
# add st.info to explain the app
st.info("For demo purposes, we set the cache expiration time to 5 minutes so you can test the difference between a cached and a non-cached query instantly.", icon="ℹ️")
st.sidebar.title(f"Hello, {name.split()[0]}!")
# Initialize session state for metrics
if 'metrics' not in st.session_state:
st.session_state['metrics'] = load_all_from_mongo(username)
# create two cols
col41, col42 = st.sidebar.columns(2)
add_metrics(col41, col42, username)
# add refresh button to reload the mongo db
st.sidebar.button("Refresh points", type="primary", on_click=refresh_metrics, use_container_width=True)  # pass the callback itself, not its result
# st.sidebar.markdown(f"""
# <p style='font-family': Garet'>Hello, {name.split()[0]}! <p> <br>
# <p style='font-family': Garet'>Your points: {user_points}</p>
# """, unsafe_allow_html=True)
# rewrite the st.info text with HTML using the Garet font family
st.markdown("""
<p style="font-family: Garet">Let's take care of our only TERRA, one PROMPT at a time with team <b>ECOGNIZE</b> 🌍 We offer a platform that trains users to write prompts in a more sustainable way! To get started, type a question or message in the chatbox and press enter. You can provide feedback on the response or ask a cached prompt to gain points.</p>
""", unsafe_allow_html=True)
feedback = None
if "messages" not in st.session_state:
st.session_state.messages = [{"role": "system", "content": RETRIEVAL_PROMPT}]
if "session_id" not in st.session_state:
st.session_state["session_id"] = str(uuid.uuid4())
if "feedback" not in st.session_state:
st.session_state.feedback = {}
if "cache" not in st.session_state:
st.session_state.cache = initialize_cache()
# Initialize session state for the checklist
if 'checklist' not in st.session_state:
st.session_state['checklist'] = {
"uses_renewable_energy": True,
# "uses_smallest_model": False,
"good_prompt_length": False,
"no_ambiguous_words": False,
"no_need_for_clarification": False
}
feedback_kwargs = {
"feedback_type": "thumbs",
"optional_text_label": "Please provide extra information",
"on_submit": _submit_feedback,
}
# Render the chat history and attach feedback widgets to assistant messages
for n, msg in enumerate(st.session_state["messages"]):
if msg["role"] == "system":
continue
contents = msg["content"]
sources = ""
with st.chat_message(msg["role"], avatar="🌍" if msg["role"] == "assistant" else None):
if msg["role"] == "assistant":
if "SOURCES:" in contents:
contents, sources = contents.split("SOURCES:", 1)
# Clean up the sources string
sources = sources.strip()
if len(sources.split("}", 1)) == 2:
sources, contents_post = sources.split("}", 1)
sources += "}"
contents += f"\n{contents_post}"
st.markdown(contents)
if len(sources) > 0:
try:
sources = json.loads(sources)
other_sources = []
for source in sources["sources"]:
if source.startswith("http") and source.endswith(".mp4"):
st.video(source)
else:
other_sources.append(source)
with st.expander("Sources"):
# Turn those sources into download links which we have the file on
for source_file in other_sources:
if source_file.startswith("http"):
st.link_button(source_file, source_file)
else:
st.text(source_file)
except Exception as e:
st.warning(f"Error parsing sources {sources}: {e}")
# Display raw sources
with st.expander("Sources"):
st.markdown(sources)
if msg["role"] == "assistant":
if n > 0:
feedback_key = f"feedback_{int(n / 2)}"
if feedback_key not in st.session_state:
st.session_state[feedback_key] = None
feedback = collector.st_feedback(
component="default",
feedback_type="thumbs",
open_feedback_label="[Optional] Provide additional feedback",
model="gpt-4",
key=feedback_key,
user_id=st.secrets["TRUBRICS_EMAIL"]
)
else:
feedback_key = f"feedback_{n}"
if feedback_key not in st.session_state:
st.session_state[feedback_key] = None
feedback = collector.st_feedback(
component="default",
feedback_type="thumbs",
open_feedback_label="[Optional] Provide additional feedback",
model="gpt-4",
key=feedback_key,
prompt_id=st.session_state.prompt_ids[n],
user_id=st.secrets["TRUBRICS_EMAIL"]
)
if feedback:
st.session_state['feedback'][feedback_key] = feedback
# Award 2 points for each piece of feedback
update_user(username, 2, 0)
# Notify the user that they have earned points
st.sidebar.success(
    "You have earned +2 points for giving feedback!"
)
if "prompt" in st.session_state:
update_checklist(st.session_state["prompt"])
display_checklist()
# Save cache locally to CSV (if it has a length > 0)
if "cache" in st.session_state and len(st.session_state.cache) > 0:
with st.sidebar:
st.title("Cached queries", help="These are the queries that have been cached so far. You can ask them again to gain points because then we don't need to call the LLM!")
display_df = st.session_state.cache.copy()
# Count TTL as the difference between the expiration date and the current date
display_df["expires"] = display_df["expires_at"].apply(lambda x: humanize.naturaltime(dt.timedelta(seconds=UTC_TIMESTAMP - x)))
# Remove entries that have expired
display_df = display_df[display_df["expires_at"] > UTC_TIMESTAMP]
st.dataframe(display_df[["query", "expires"]], hide_index=True, use_container_width=True)
st.session_state.cache.to_csv("cache.csv", index=False)
if prompt := st.chat_input("Ask me anything"):
    # Add the user message to the session state and render it
    st.session_state['messages'].append({"role": "user", "content": prompt})
    # Store the raw prompt so the eco checklist in the sidebar can evaluate it on the next rerun
    st.session_state["prompt"] = prompt
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant", avatar="🌍"):
# For streaming, we need to loop through the response generator
reply_text = ""
function_name = ""
function_args = ""
for chunk in openai.ChatCompletion.create(
# model="gpt-35-turbo-16k",
deployment_id="gpt-35-turbo-16k",
messages=st.session_state['messages'],
max_tokens=1000,
timeout=TIMEOUT,
function_call={"name": "search_bing"},
functions=functions,
stream=True,
):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", None)
function_call = delta.get("function_call", None)
if function_call is not None:
function_name += function_call.get("name", "")
function_args += function_call.get("arguments", "")
if content is not None:
reply_text += content
# Sanitize output
if reply_text.startswith("AI: "):
reply_text = reply_text.split("AI: ", 1)[1]
# Collect full function call
if function_name != "" and function_args != "":
function_call = {"name": function_name, "arguments": function_args}
else:
function_call = None
if function_call is None: # Not a function call, return normal message
# Sanitize
if reply_text.startswith("AI: "):
reply_text = reply_text.split("AI: ", 1)[1]
st.markdown(reply_text)
# Model wants to call a function and call it, if appropriate
else:
# Read the function call from model response and execute it
fun_name = function_call.get("name", None)
if fun_name and fun_name in available_functions:
function = available_functions[fun_name]
else:
function = None
fun_args = function_call.get("arguments", None)
if fun_args is not None and isinstance(fun_args, str):
    fun_args = json.loads(fun_args)
# Guard against a missing or unparsable arguments payload
query = fun_args.get("query", None) if isinstance(fun_args, dict) else None
# Check in the cache if the response is already there and if so just return the relevant answer
if query is not None:
# Vectorize the query
query_embedding = embeddings.embed_query(query)
temp_cache = st.session_state.cache.copy()
# Filter out expired cache entries
temp_cache = temp_cache[temp_cache["expires_at"] > UTC_TIMESTAMP]
# Update the cache
st.session_state.cache = temp_cache
# Calculate the cosine similarity between the query embedding and each cached embedding
# (embeddings read back from cache.csv are JSON strings; in-memory ones may already be lists)
try:
    temp_cache["similarity"] = temp_cache["embedding"].apply(
        lambda x: cosine_similarity(json.loads(x) if isinstance(x, str) else x, query_embedding)
    )
except Exception:
    traceback.print_exc()
    temp_cache["similarity"] = 0.0  # fall back to "no match" so the sort below cannot fail
# Sort the cache by similarity, descending
temp_cache = temp_cache.sort_values(by=["similarity"], ascending=False)
# See if the top result is above a certain threshold
if len(temp_cache) > 0 and temp_cache.iloc[0]["similarity"] > CACHE_SIMILARITY_THRESHOLD:
# Directly add that answer as the chat response
reply_text = temp_cache.iloc[0]["answer"]
cached_query = temp_cache.iloc[0]["query"]
st.session_state['messages'].append({"role": "assistant", "content": f"`[We found a cached answer! 🌿 The cached question was: {cached_query} - Not the question you were asking? Please rate the answer below so our team can improve the experience. Thanks!]`\n\n{reply_text}"})
# add points to the user
update_user(username, 2, 1)
st.rerun()
if function is None:
fun_res = ["Error, no function specified"]
elif query is None:
fun_res = ["Error, no query specified"]
else:
with st.status(f"Called function `{fun_name}`"):
st.json(fun_args, expanded=True)
fun_res = function(fun_args)
# Build an abridged, temporary message to be fed into a more powerful GPT-4 model to limit the number of tokens
# Proceed to the secondary call to generate results
messages = [{"role": "system", "content": GENERATION_PROMPT}]
if query is not None:
messages.append({"role": "user", "content": query})
messages.extend([{"role": "function", "name": fun_name, "content": one_fun_res} for one_fun_res in fun_res])
# For streaming, we need to loop through the response generator
message_placeholder = st.empty()
reply_text = ""
for chunk in openai.ChatCompletion.create(
# model="gpt-4",
deployment_id="gpt-4",
messages=messages,
max_tokens=1500,
timeout=TIMEOUT,
# function_call="auto",
# functions=[],
stream=True,
):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content", None)
function_call = delta.get("function_call", None)
if function_call is not None:
function_name += function_call.get("name", "")
function_args += function_call.get("arguments", "")
if content is not None:
reply_text += content
# Sanitize output
if reply_text.startswith("AI: "):
reply_text = reply_text.split("AI: ", 1)[1]
render_text = reply_text
if "SOURCES:" in render_text:
render_text, sources = render_text.split("SOURCES:", 1)
# Continuously write the response in Streamlit
message_placeholder.markdown(render_text)
# # Collect full function call
# if function_name != "" and function_args != "":
# function_call = {"name": function_name, "arguments": function_args}
# else:
# function_call = None
# Sanitize
if reply_text.startswith("AI: "):
reply_text = reply_text.split("AI: ", 1)[1]
render_text = reply_text
if "SOURCES:" in render_text:
render_text, sources = render_text.split("SOURCES:", 1)
message_placeholder.markdown(render_text)
logged_prompt = collector.log_prompt(
config_model={"model": "gpt-4"},
prompt=prompt,
generation=reply_text,
session_id=st.session_state.session_id,
# tags=tags,
user_id=str(st.secrets["TRUBRICS_EMAIL"])
)
# Add the query, its embedding, the answer (reply_text) and the expiration time to the cache. Expires in 5 minutes for the demo (a production setting might use e.g. 1 day)
st.session_state.cache = pd.concat([
st.session_state.cache,
pd.DataFrame([
{
"query": query,
"embedding": json.dumps(query_embedding),
"answer": reply_text,
"expires_at": UTC_TIMESTAMP + 5 * 60
}
])
])
st.session_state.prompt_ids.append(logged_prompt.id)
# After getting the response, add it to the session state
st.session_state['messages'].append({"role": "assistant", "content": reply_text})
st.rerun()
# should be the end of the sidebar
with st.sidebar:
authenticator.logout("Logout", "main", key="unique_key")
elif st.session_state["authentication_status"] is False:
st.error("Username/password is incorrect")
elif st.session_state["authentication_status"] is None:
st.info("Please enter your username and password. Use the default username and password (outokumpu: junction2023) if you are not registered yet!")
# if feedback:
# st.write(feedback)