Stable GUI (#29)
+ Way better qa_system initialization.
+ Tweaked UI
alxspiker authored May 14, 2023
1 parent 1909aa2 commit a6a9dc4
Showing 2 changed files with 68 additions and 64 deletions.
34 changes: 21 additions & 13 deletions gui.py
@@ -17,9 +17,10 @@
 model_stop = os.environ.get('MODEL_STOP')
 
 # Initialization
-if "input" not in st.session_state:
+if "initialized" not in st.session_state:
     st.session_state.input = ""
     st.session_state.running = False
+    st.session_state.initialized = False
 
 st.set_page_config(page_title="CASALIOY")

@@ -49,10 +50,9 @@
 
 colored_header(label='', description='', color_name='blue-30')
 response_container = st.container()
-
-
 
 def generate_response(input=""):
+    print("Input:"+input)
     with response_container:
         col1, col2, col3 = st.columns(3)
         with col1:
@@ -68,24 +68,32 @@ def generate_response(input=""):
         os.environ["MODEL_STOP"] = str(st.session_state.stops_input)
         dotenv.set_key(dotenv_file, "MODEL_STOP", os.environ["MODEL_STOP"])
         #with st.form("my_form", clear_on_submit=True):
-        if st.session_state['generated']:
+        if 'generated' in st.session_state:
             for i in range(len(st.session_state['generated'])):
                 message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
                 message(st.session_state["generated"][i], key=str(i))
         if input.strip() != "":
-            st.session_state.past.append(st.session_state.input)
+            st.session_state.running=True
+            st.session_state.past.append(input)
         if st.session_state.running:
-            message(st.session_state.input, is_user=True)
+            message(input, is_user=True)
             message("Loading response. Please wait for me to finish before refreshing the page...", key="rmessage")
-            #startLLM.qdrant = None #Not sure why this fixes db error
-            response = startLLM.main(st.session_state.input, True)
-            st.session_state.generated.append(response)
-            message(response)
+            if st.session_state.initialized == False:
+                st.session_state.initialized = True
+                print("Initializing...")
+                startLLM.initialize_qa_system()
+            else:
+                print("Already initialized!")
+            response=startLLM.qa_system(st.session_state.input)
+            st.session_state.input = ""
+            answer, docs = response['result'], response['source_documents']
+            st.session_state.generated.append(answer)
+            message(answer)
             st.session_state.running = False
-st.text_input("You: ", "", key="input", disabled=st.session_state.running)
 
 
-with st.form("my_form", clear_on_submit=True):
-with form:
+form = st.form(key="input-form", clear_on_submit=True)
+with form:
     st.text_input("You: ", "", key="input", disabled=st.session_state.running)
     st.form_submit_button('SUBMIT', on_click=generate_response(st.session_state.input), disabled=st.session_state.running)

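The heart of the gui.py change is the "initialized" flag kept in st.session_state, which guards the expensive one-time model load against Streamlit's rerun-on-every-interaction execution model. Below is a minimal sketch of that pattern; load_model() is a hypothetical stand-in for startLLM.initialize_qa_system():

import streamlit as st

def load_model():
    # Placeholder for the real QA chain; the commit calls
    # startLLM.initialize_qa_system() at this point instead.
    return object()

if "initialized" not in st.session_state:
    st.session_state.input = ""
    st.session_state.running = False
    st.session_state.initialized = False

if not st.session_state.initialized:
    load_model()                          # first run of the session only
    st.session_state.initialized = True   # persists across widget reruns

form = st.form(key="input-form", clear_on_submit=True)
with form:
    st.text_input("You: ", "", key="input", disabled=st.session_state.running)
    st.form_submit_button("SUBMIT", disabled=st.session_state.running)

Because st.session_state survives the rerun Streamlit performs on every widget interaction, the flag ensures the model is loaded once per session rather than on every submit.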
98 changes: 47 additions & 51 deletions startLLM.py
@@ -18,61 +18,57 @@
 qa_system=None
 
 def initialize_qa_system():
-    # Load stored vectorstore
-    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
-    # Load ggml-formatted model
-    local_path = model_path
-
-    client = qdrant_client.QdrantClient(
-        path=persist_directory, prefer_grpc=True
-    )
-    qdrant = Qdrant(
-        client=client, collection_name="test",
-        embeddings=llama
-    )
-
-    # Prepare the LLM chain
-    callbacks = [StreamingStdOutCallbackHandler()]
-    match model_type:
-        case "LlamaCpp":
-            from langchain.llms import LlamaCpp
-            llm = LlamaCpp(model_path=local_path, n_ctx=model_n_ctx, temperature=model_temp, stop=model_stop, callbacks=callbacks, verbose=True)
-        case "GPT4All":
-            from langchain.llms import GPT4All
-            llm = GPT4All(model=local_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=True, backend='gptj')
-        case _default:
-            print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
-    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
-    return qa
-
-def main(prompt="", gui=False):
-    global qa_system
-    if qa_system is None:
-        qa_system = initialize_qa_system()
+    # Load stored vectorstore
+    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
+    # Load ggml-formatted model
+    local_path = model_path
+
+    client = qdrant_client.QdrantClient(
+        path=persist_directory, prefer_grpc=True
+    )
+    qdrant = Qdrant(
+        client=client, collection_name="test",
+        embeddings=llama
+    )
+
+    # Prepare the LLM chain
+    callbacks = [StreamingStdOutCallbackHandler()]
+    match model_type:
+        case "LlamaCpp":
+            from langchain.llms import LlamaCpp
+            llm = LlamaCpp(model_path=local_path, n_ctx=model_n_ctx, temperature=model_temp, stop=model_stop, callbacks=callbacks, verbose=True)
+        case "GPT4All":
+            from langchain.llms import GPT4All
+            llm = GPT4All(model=local_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=True, backend='gptj')
+        case _default:
+            print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
+    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
+    qa_system = qa
+
+def main():
+    initialize_qa_system()
     # Interactive questions and answers
-    if (prompt.strip() != "" and gui) or gui==False:
-        while True:
-            query = prompt if gui else input("\nEnter a query: ")
-            if query == "exit":
-                break
-
-            # Get the answer from the chain
-            res = qa_system(query)
-            answer, docs = res['result'], res['source_documents']
+    while True:
+        query = input("\nEnter a query: ")
+        if query == "exit":
+            break
+
+        # Get the answer from the chain
+        res = qa_system(query)
+        answer, docs = res['result'], res['source_documents']
 
-            # Print the result
-            print("\n\n> Question:")
-            print(query)
-            print("\n> Answer:")
-            print(answer)
-
-            # Print the relevant sources used for the answer
-            for document in docs:
-                print("\n> " + document.metadata["source"] + ":")
-                print(document.page_content)
-
-            if gui:
-                return answer
+        # Print the result
+        print("\n\n> Question:")
+        print(query)
+        print("\n> Answer:")
+        print(answer)
+
+        # Print the relevant sources used for the answer
+        for document in docs:
+            print("\n> " + document.metadata["source"] + ":")
+            print(document.page_content)
 
 if __name__ == "__main__":
     main()
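The startLLM.py side reduces to a module-level singleton: initialize_qa_system() builds the chain once and publishes it through the qa_system global, which the CLI loop (and the GUI) then call. Below is a minimal sketch under that reading. build_chain() is a hypothetical stand-in for the embeddings/Qdrant/RetrievalQA setup, the is-None guard is illustrative (the commit tracks initialization on the GUI side instead), and the explicit global declaration is what lets the function rebind the module-level name:

qa_system = None

def build_chain():
    # Stand-in for: LlamaCppEmbeddings + Qdrant retriever + RetrievalQA.
    return lambda query: {"result": "echo: " + query, "source_documents": []}

def initialize_qa_system():
    global qa_system              # rebind the module-level name
    if qa_system is None:         # illustrative guard; repeat calls are cheap
        qa_system = build_chain()

def main():
    initialize_qa_system()
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break
        res = qa_system(query)
        print("\n> Answer:")
        print(res["result"])

if __name__ == "__main__":
    main()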
