From d6917281231f7f690f26dd36ee52eb7b22fe25dd Mon Sep 17 00:00:00 2001
From: su77ungr <69374354+su77ungr@users.noreply.github.com>
Date: Sun, 14 May 2023 14:18:23 +0200
Subject: [PATCH] Revert "Stable GUI (#29)"

This reverts commit a6a9dc435a1bb74b822724a601212b22bc6e2051.
---
 gui.py      | 34 +++++++------------
 startLLM.py | 98 ++++++++++++++++++++++++++++-------------------------
 2 files changed, 64 insertions(+), 68 deletions(-)

diff --git a/gui.py b/gui.py
index 3918dd0..fd3e2a0 100644
--- a/gui.py
+++ b/gui.py
@@ -17,10 +17,9 @@
 model_stop = os.environ.get('MODEL_STOP')
 
 # Initialization
-if "initialized" not in st.session_state:
+if "input" not in st.session_state:
     st.session_state.input = ""
     st.session_state.running = False
-    st.session_state.initialized = False
 
 st.set_page_config(page_title="CASALIOY")
 
@@ -50,9 +49,10 @@
 colored_header(label='', description='', color_name='blue-30')
 response_container = st.container()
 
 
+
+
 def generate_response(input=""):
-    print("Input:"+input)
     with response_container:
         col1, col2, col3 = st.columns(3)
         with col1:
@@ -68,32 +68,24 @@ def generate_response(input=""):
             os.environ["MODEL_STOP"] = str(st.session_state.stops_input)
             dotenv.set_key(dotenv_file, "MODEL_STOP", os.environ["MODEL_STOP"])
 
     #with st.form("my_form", clear_on_submit=True):
-    if 'generated' in st.session_state:
+    if st.session_state['generated']:
         for i in range(len(st.session_state['generated'])):
             message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
             message(st.session_state["generated"][i], key=str(i))
     if input.strip() != "":
         st.session_state.running=True
-        st.session_state.past.append(input)
+        st.session_state.past.append(st.session_state.input)
     if st.session_state.running:
-        message(input, is_user=True)
+        message(st.session_state.input, is_user=True)
         message("Loading response. Please wait for me to finish before refreshing the page...", key="rmessage")
         #startLLM.qdrant = None #Not sure why this fixes db error
-        if st.session_state.initialized == False:
-            st.session_state.initialized = True
-            print("Initializing...")
-            startLLM.initialize_qa_system()
-        else:
-            print("Already initialized!")
-        response=startLLM.qa_system(st.session_state.input)
-        st.session_state.input = ""
-        answer, docs = response['result'], response['source_documents']
-        st.session_state.generated.append(answer)
-        message(answer)
+        response = startLLM.main(st.session_state.input, True)
+        st.session_state.generated.append(response)
+        message(response)
         st.session_state.running = False
-        with form:
-            st.text_input("You: ", "", key="input", disabled=st.session_state.running)
-form = st.form(key="input-form", clear_on_submit=True)
-with form:
+        st.text_input("You: ", "", key="input", disabled=st.session_state.running)
+
+
+with st.form("my_form", clear_on_submit=True):
     st.form_submit_button('SUBMIT', on_click=generate_response(st.session_state.input), disabled=st.session_state.running)
\ No newline at end of file
diff --git a/startLLM.py b/startLLM.py
index 14baa24..503a1b2 100644
--- a/startLLM.py
+++ b/startLLM.py
@@ -18,57 +18,61 @@
 qa_system=None
 
 def initialize_qa_system():
-    global qa_system
-    if qa_system is None:
-        # Load stored vectorstore
-        llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
-        # Load ggml-formatted model
-        local_path = model_path
+    # Load stored vectorstore
+    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
+    # Load ggml-formatted model
+    local_path = model_path
+
+    client = qdrant_client.QdrantClient(
+        path=persist_directory, prefer_grpc=True
+    )
+    qdrant = Qdrant(
+        client=client, collection_name="test",
+        embeddings=llama
+    )
 
-        client = qdrant_client.QdrantClient(
-            path=persist_directory, prefer_grpc=True
-        )
-        qdrant = Qdrant(
-            client=client, collection_name="test",
-            embeddings=llama
-        )
-
-        # Prepare the LLM chain
-        callbacks = [StreamingStdOutCallbackHandler()]
-        match model_type:
-            case "LlamaCpp":
-                from langchain.llms import LlamaCpp
-                llm = LlamaCpp(model_path=local_path, n_ctx=model_n_ctx, temperature=model_temp, stop=model_stop, callbacks=callbacks, verbose=True)
-            case "GPT4All":
-                from langchain.llms import GPT4All
-                llm = GPT4All(model=local_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=True, backend='gptj')
-            case _default:
-                print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
-        qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
-        qa_system = qa
+    # Prepare the LLM chain
+    callbacks = [StreamingStdOutCallbackHandler()]
+    match model_type:
+        case "LlamaCpp":
+            from langchain.llms import LlamaCpp
+            llm = LlamaCpp(model_path=local_path, n_ctx=model_n_ctx, temperature=model_temp, stop=model_stop, callbacks=callbacks, verbose=True)
+        case "GPT4All":
+            from langchain.llms import GPT4All
+            llm = GPT4All(model=local_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=True, backend='gptj')
+        case _default:
+            print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
+    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
+    return qa
 
-def main():
-    initialize_qa_system()
+def main(prompt="", gui=False):
+    global qa_system
+    if qa_system is None:
+        qa_system = initialize_qa_system()
     # Interactive questions and answers
-    while True:
-        query = input("\nEnter a query: ")
-        if query == "exit":
-            break
-
-        # Get the answer from the chain
-        res = qa_system(query)
-        answer, docs = res['result'], res['source_documents']
+    if (prompt.strip() != "" and gui) or gui==False:
+        while True:
+            query = prompt if gui else input("\nEnter a query: ")
+            if query == "exit":
+                break
+
+            # Get the answer from the chain
+            res = qa_system(query)
+            answer, docs = res['result'], res['source_documents']
 
-        # Print the result
-        print("\n\n> Question:")
-        print(query)
-        print("\n> Answer:")
-        print(answer)
-
-        # Print the relevant sources used for the answer
-        for document in docs:
-            print("\n> " + document.metadata["source"] + ":")
-            print(document.page_content)
+            # Print the result
+            print("\n\n> Question:")
+            print(query)
+            print("\n> Answer:")
+            print(answer)
+
+            # Print the relevant sources used for the answer
+            for document in docs:
+                print("\n> " + document.metadata["source"] + ":")
+                print(document.page_content)
+
+            if gui:
+                return answer
 
 if __name__ == "__main__":
     main()
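
Note (not part of the patch): with this revert applied, the GUI drives the model through startLLM.main(prompt, gui=True), which lazily builds the module-level qa_system on first use and returns the answer string, while the CLI keeps its interactive loop. A minimal sketch of both call paths, assuming a configured .env, the Qdrant store in the configured persist directory, and the models on disk (the example prompt is hypothetical):

    import startLLM

    # GUI path: one-shot query; main() initializes qa_system on the first
    # call, answers once inside the loop, and returns the answer string
    # that gui.py renders with message().
    answer = startLLM.main("What does this repo do?", True)
    print(answer)

    # CLI path: no arguments; main() falls back to the interactive
    # input() loop and keeps answering until the user types "exit".
    startLLM.main()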