-
Notifications
You must be signed in to change notification settings - Fork 7
/
csv_langchain_streamlit_chatbot.py
78 lines (55 loc) · 2.73 KB
/
csv_langchain_streamlit_chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import streamlit as st
from pathlib import Path
from streamlit_chat import message
from langchain.document_loaders import CSVLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
import os
st.title('Langchain ChatBot for Your Custom CSV DATA POWERED BY OPEN AI')
OPENAI_API_KEY = st.text_input('Enter your Open AI API KEY', 'API KEY')
DESTINATION_DIR= st.text_input('Enter your destination file dir', 'DIR')
if st.button("Set API KEY"):
st.write("Your OPEN AI API KEY IS" , OPENAI_API_KEY)
os.environ["OPENAI_API_KEY"]=OPENAI_API_KEY
csv_file_uploaded = st.file_uploader(label="Upload your CSV File here")
if csv_file_uploaded is not None:
def save_file_to_folder(uploadedFile):
# Save uploaded file to 'content' folder.
save_folder = DESTINATION_DIR
save_path = Path(save_folder, uploadedFile.name)
with open(save_path, mode='wb') as w:
w.write(uploadedFile.getvalue())
if save_path.exists():
st.success(f'File {uploadedFile.name} is successfully saved!')
save_file_to_folder(csv_file_uploaded)
loader = CSVLoader(file_path=os.path.join(DESTINATION_DIR, csv_file_uploaded.name))
# Create an index using the loaded documents
index_creator = VectorstoreIndexCreator()
docsearch = index_creator.from_loaders([loader])
# Create a question-answering chain using the index
chain = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=docsearch.vectorstore.as_retriever(), input_key="question")
#Creating the chatbot interface
st.title("Chat with your CSV Data")
# Storing the chat
if 'generated' not in st.session_state:
st.session_state['generated'] = []
if 'past' not in st.session_state:
st.session_state['past'] = []
def generate_response(user_query):
response = chain({"question": user_query})
return response['result']
# We will get the user's input by calling the get_text function
def get_text():
input_text = st.text_input("You: ","Ask Question From your Document?", key="input")
return input_text
user_input = get_text()
if user_input:
output = generate_response(user_input)
# store the output
st.session_state.past.append(user_input)
st.session_state.generated.append(output)
if st.session_state['generated']:
for i in range(len(st.session_state['generated'])-1, -1, -1):
message(st.session_state["generated"][i], key=str(i))
message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')