This repository has been archived by the owner on Sep 2, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
81 lines (64 loc) · 2.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Main entrypoint for the app."""
import logging
import pickle
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.templating import Jinja2Templates
from langchain.vectorstores import VectorStore
from callback import QuestionGenCallbackHandler, StreamingLLMCallbackHandler
from query_data import get_chain
from schemas import ChatResponse
# Module-level singletons shared by every route below.
# The vector store starts out empty; ``startup_event`` fills it in by
# unpickling ``vectorstore.pkl`` when the application starts.
vectorstore: Optional[VectorStore] = None
# HTML templates are looked up in the ``templates`` directory.
templates = Jinja2Templates(directory="templates")
app = FastAPI()
@app.on_event("startup")
async def startup_event():
    """Load the pickled vector store into the module-level ``vectorstore``.

    The file ``vectorstore.pkl`` is resolved relative to the current working
    directory and is expected to have been produced by ``ingest.py``.

    Raises:
        ValueError: if ``vectorstore.pkl`` does not exist.
    """
    global vectorstore
    logging.info("loading vectorstore")
    store_path = Path("vectorstore.pkl")
    # EAFP: open directly rather than calling exists() first, so there is no
    # window between the check and the open in which the file can disappear.
    try:
        store_file = store_path.open("rb")
    except FileNotFoundError as err:
        raise ValueError(
            "vectorstore.pkl does not exist, please run ingest.py first"
        ) from err
    with store_file:
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # This file must come from a trusted source (the local ingest.py run),
        # never from user-supplied input.
        vectorstore = pickle.load(store_file)
@app.get("/")
async def get(request: Request):
    """Render ``index.html`` from the templates directory for the root URL.

    The incoming request is placed in the template context under the key
    ``"request"``, which the template response requires.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one chat session over a WebSocket at ``/chat``.

    Loops until the client disconnects. For each text message received it:

    1. echoes the question back as ``ChatResponse(sender="you", type="stream")``;
    2. sends an empty bot message of type ``"start"``;
    3. awaits ``qa_chain.acall`` with the question and this connection's chat
       history (the callback handlers built around ``websocket`` are
       presumably what stream the answer to the client -- TODO confirm in
       callback.py);
    4. appends ``(question, answer)`` to the chat history;
    5. sends an empty bot message of type ``"end"``.

    A ``WebSocketDisconnect`` ends the session. Any other exception is logged
    with its traceback and reported to the client as a bot message of type
    ``"error"``, after which the loop waits for the next question. If even
    that error reply cannot be sent, the session ends.
    """
    await websocket.accept()
    question_handler = QuestionGenCallbackHandler(websocket)
    stream_handler = StreamingLLMCallbackHandler(websocket)
    # Per-connection history of (question, answer) tuples. It is never
    # trimmed, so it grows for as long as the socket stays open.
    chat_history = []
    qa_chain = get_chain(vectorstore, question_handler, stream_handler)
    # To enable tracing, use the line below instead of the one above and make
    # sure `langchain-server` is running. get_chain is called synchronously
    # here, so no `await` is needed.
    # qa_chain = get_chain(vectorstore, question_handler, stream_handler, tracing=True)

    while True:
        try:
            # Receive the client's question and echo it back.
            question = await websocket.receive_text()
            resp = ChatResponse(sender="you", message=question, type="stream")
            await websocket.send_json(resp.dict())

            # Bracket the bot's answer with "start" / "end" markers.
            start_resp = ChatResponse(sender="bot", message="", type="start")
            await websocket.send_json(start_resp.dict())

            result = await qa_chain.acall(
                {"question": question, "chat_history": chat_history}
            )
            chat_history.append((question, result["answer"]))

            end_resp = ChatResponse(sender="bot", message="", type="end")
            await websocket.send_json(end_resp.dict())
        except WebSocketDisconnect:
            logging.info("websocket disconnect")
            break
        except Exception as e:
            # logging.exception logs the same message as logging.error(e) but
            # also records the traceback, which is needed to debug chain errors.
            logging.exception(e)
            resp = ChatResponse(
                sender="bot",
                message="Sorry, something went wrong. Try again.",
                type="error",
            )
            try:
                await websocket.send_json(resp.dict())
            except Exception:
                # The connection is no longer usable (e.g. the client went away
                # mid-answer). End the session instead of letting the error escape.
                logging.exception("could not send error response; closing session")
                break
if __name__ == "__main__":
    # Run the server directly when this module is executed as a script.
    # NOTE: host "0.0.0.0" listens on all network interfaces, not just localhost.
    import uvicorn

    uvicorn.run(app, port=9000, host="0.0.0.0")