From 214c449c8fa41735e5cce6054a09cd398450597c Mon Sep 17 00:00:00 2001 From: Abdullah Al Amin Date: Mon, 11 Sep 2023 12:35:16 +0200 Subject: [PATCH] chore: improve chat memory code (#38) --- ai/question.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/ai/question.py b/ai/question.py index 460b325..562116c 100644 --- a/ai/question.py +++ b/ai/question.py @@ -18,24 +18,21 @@ def _cleanup_chats(): for chat_id in chatMemories.keys(): # drop chats that are older than 5 minutes - if time.time() - chatMemories[chat_id]["lastQuestion"] > 5 * 60: - del chatMemories[chat_id] + if time.time() - chatMemories.get(chat_id,{}).get("lastQuestionTime") > 5 * 60: + chatMemories.pop(chat_id) -def _get_chat(chat_id: str): +def _get_chat_memory(chat_id: str)->ConversationBufferMemory: _cleanup_chats() - - if chat_id not in chatMemories: - chatMemories[chat_id] = { + + chatMemories.setdefault(chat_id, { "memory": ConversationBufferMemory( memory_key="chat_history", return_messages=True, output_key="answer" ), - "lastQuestion": time.time(), - } - - chatMemories[chat_id]["lastQuestion"] = time.time() - - return chatMemories[chat_id]["memory"] + "lastQuestionTime": time.time(), + }).set('lastQuestionTime',time.time()) + + return chatMemories.get(chat_id,{}).get("memory") def ask(repo_id: int, chat_id: str, question: str): @@ -46,14 +43,13 @@ def ask(repo_id: int, chat_id: str, question: str): db = FAISS.load_local(os.path.join(repo_path, "vector_store"), embeddings) retriever = db.as_retriever() - end = time.time() retriever.search_kwargs["distance_metric"] = "cos" retriever.search_kwargs["fetch_k"] = 100 retriever.search_kwargs["maximal_marginal_relevance"] = True retriever.search_kwargs["k"] = 20 - memory = _get_chat(chat_id) + memory = _get_chat_memory(chat_id) qa = ConversationalRetrievalChain.from_llm( llm=ChatOpenAI(temperature=0), @@ -61,12 +57,10 @@ def ask(repo_id: int, chat_id: str, question: str): retriever=retriever, return_source_documents=True, ) - end = time.time() result = qa(question) print(f"Answer: {result['answer']}") print(f"Sources: {[x.metadata['source'] for x in result['source_documents']]}") - end = time.time() return result["answer"], [x.metadata["source"] for x in result["source_documents"]]