Stable GUI #29

Merged — 1 commit, May 14, 2023
gui.py — 34 changes: 21 additions & 13 deletions
@@ -17,9 +17,10 @@
 model_stop = os.environ.get('MODEL_STOP')
 
 # Initialization
-if "input" not in st.session_state:
+if "initialized" not in st.session_state:
     st.session_state.input = ""
     st.session_state.running = False
+    st.session_state.initialized = False
 
 st.set_page_config(page_title="CASALIOY")
 
@@ -49,10 +50,9 @@
 colored_header(label='', description='', color_name='blue-30')
 response_container = st.container()
 
 
-
 def generate_response(input=""):
     print("Input:"+input)
     with response_container:
         col1, col2, col3 = st.columns(3)
         with col1:
@@ -68,24 +68,32 @@ def generate_response(input=""):
             os.environ["MODEL_STOP"] = str(st.session_state.stops_input)
             dotenv.set_key(dotenv_file, "MODEL_STOP", os.environ["MODEL_STOP"])
     #with st.form("my_form", clear_on_submit=True):
-    if st.session_state['generated']:
+    if 'generated' in st.session_state:
         for i in range(len(st.session_state['generated'])):
             message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
             message(st.session_state["generated"][i], key=str(i))
     if input.strip() != "":
         st.session_state.running=True
-        st.session_state.past.append(st.session_state.input)
+        st.session_state.past.append(input)
     if st.session_state.running:
-        message(st.session_state.input, is_user=True)
+        message(input, is_user=True)
         message("Loading response. Please wait for me to finish before refreshing the page...", key="rmessage")
         #startLLM.qdrant = None #Not sure why this fixes db error
-        response = startLLM.main(st.session_state.input, True)
-        st.session_state.generated.append(response)
-        message(response)
+        if st.session_state.initialized == False:
+            st.session_state.initialized = True
+            print("Initializing...")
+            startLLM.initialize_qa_system()
+        else:
+            print("Already initialized!")
+        response=startLLM.qa_system(st.session_state.input)
+        st.session_state.input = ""
+        answer, docs = response['result'], response['source_documents']
+        st.session_state.generated.append(answer)
+        message(answer)
         st.session_state.running = False
-st.text_input("You: ", "", key="input", disabled=st.session_state.running)
 
 
-with st.form("my_form", clear_on_submit=True):
-with form:
     st.text_input("You: ", "", key="input", disabled=st.session_state.running)
+form = st.form(key="input-form", clear_on_submit=True)
+with form:
     st.form_submit_button('SUBMIT', on_click=generate_response(st.session_state.input), disabled=st.session_state.running)

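The gui.py change boils down to a run-once guard kept in st.session_state: the heavyweight LLM setup happens a single time per browser session instead of on every Streamlit rerun, and the app then calls the already-built startLLM.qa_system directly. A minimal, self-contained sketch of that pattern follows; it is illustrative only, and load_qa_system is a hypothetical stand-in for startLLM.initialize_qa_system:

import streamlit as st

def load_qa_system():
    # Hypothetical stand-in for startLLM.initialize_qa_system(); assume it is
    # expensive (loads embeddings, the vectorstore, and the LLM).
    return lambda query: {"result": "echo: " + query, "source_documents": []}

# Run the expensive setup exactly once per session; reruns skip this branch.
if "initialized" not in st.session_state:
    st.session_state.qa = load_qa_system()
    st.session_state.initialized = True

query = st.text_input("You: ")
if query.strip() != "":
    st.write(st.session_state.qa(query)["result"])

Newer Streamlit releases also provide st.cache_resource for exactly this load-once use case, which avoids hand-rolling the guard.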
startLLM.py — 98 changes: 47 additions & 51 deletions
@@ -18,61 +18,57 @@
 qa_system=None
 
 def initialize_qa_system():
-    # Load stored vectorstore
-    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
-    # Load ggml-formatted model
-    local_path = model_path
-
-    client = qdrant_client.QdrantClient(
-        path=persist_directory, prefer_grpc=True
-    )
-    qdrant = Qdrant(
-        client=client, collection_name="test",
-        embeddings=llama
-    )
-
-    # Prepare the LLM chain
-    callbacks = [StreamingStdOutCallbackHandler()]
-    match model_type:
-        case "LlamaCpp":
-            from langchain.llms import LlamaCpp
-            llm = LlamaCpp(model_path=local_path, n_ctx=model_n_ctx, temperature=model_temp, stop=model_stop, callbacks=callbacks, verbose=True)
-        case "GPT4All":
-            from langchain.llms import GPT4All
-            llm = GPT4All(model=local_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=True, backend='gptj')
-        case _default:
-            print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
-    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
-    return qa
-
-def main(prompt="", gui=False):
     global qa_system
-    if qa_system is None:
-        qa_system = initialize_qa_system()
+    # Load stored vectorstore
+    llama = LlamaCppEmbeddings(model_path=llama_embeddings_model, n_ctx=model_n_ctx)
+    # Load ggml-formatted model
+    local_path = model_path
+
+    client = qdrant_client.QdrantClient(
+        path=persist_directory, prefer_grpc=True
+    )
+    qdrant = Qdrant(
+        client=client, collection_name="test",
+        embeddings=llama
+    )
+
+    # Prepare the LLM chain
+    callbacks = [StreamingStdOutCallbackHandler()]
+    match model_type:
+        case "LlamaCpp":
+            from langchain.llms import LlamaCpp
+            llm = LlamaCpp(model_path=local_path, n_ctx=model_n_ctx, temperature=model_temp, stop=model_stop, callbacks=callbacks, verbose=True)
+        case "GPT4All":
+            from langchain.llms import GPT4All
+            llm = GPT4All(model=local_path, n_ctx=model_n_ctx, callbacks=callbacks, verbose=True, backend='gptj')
+        case _default:
+            print("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")
+    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=qdrant.as_retriever(search_type="mmr"), return_source_documents=True)
+    qa_system = qa
+
+def main():
+    initialize_qa_system()
     # Interactive questions and answers
-    if (prompt.strip() != "" and gui) or gui==False:
-        while True:
-            query = prompt if gui else input("\nEnter a query: ")
-            if query == "exit":
-                break
-
-            # Get the answer from the chain
-            res = qa_system(query)
-            answer, docs = res['result'], res['source_documents']
+    while True:
+        query = input("\nEnter a query: ")
+        if query == "exit":
+            break
+
+        # Get the answer from the chain
+        res = qa_system(query)
+        answer, docs = res['result'], res['source_documents']
 
-            # Print the result
-            print("\n\n> Question:")
-            print(query)
-            print("\n> Answer:")
-            print(answer)
-
-            # Print the relevant sources used for the answer
-            for document in docs:
-                print("\n> " + document.metadata["source"] + ":")
-                print(document.page_content)
-
-            if gui:
-                return answer
+        # Print the result
+        print("\n\n> Question:")
+        print(query)
+        print("\n> Answer:")
+        print(answer)
+
+        # Print the relevant sources used for the answer
+        for document in docs:
+            print("\n> " + document.metadata["source"] + ":")
+            print(document.page_content)
 
 if __name__ == "__main__":
     main()
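On the startLLM.py side, the PR moves the chain construction out of main() and into initialize_qa_system(), which now publishes the result through the module-level qa_system global instead of returning it; the GUI calls initialize_qa_system() once and then invokes startLLM.qa_system directly, while main() keeps a plain CLI loop. A compact sketch of that shape follows. It is not the PR's code: build_chain is a hypothetical stand-in for the LlamaCpp/GPT4All + Qdrant + RetrievalQA setup, and the `is None` guard is an extra hardening idea (not in the PR) that would make repeated initialization calls harmless:

qa_system = None  # module-level singleton shared by the CLI loop and the GUI

def build_chain():
    # Hypothetical stand-in for the embeddings + vectorstore + LLM setup.
    return lambda query: {"result": query.upper(), "source_documents": []}

def initialize_qa_system():
    global qa_system
    if qa_system is None:  # guard not in the PR: makes repeated calls a no-op
        qa_system = build_chain()

def main():
    initialize_qa_system()
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break
        res = qa_system(query)
        print("\n> Answer:")
        print(res["result"])

if __name__ == "__main__":
    main()

Publishing the chain through a global keeps the two entry points (CLI and Streamlit) on one shared instance, which is what lets gui.py skip re-loading the model on every rerun.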