Skip to content

Commit

Permalink
chore: improve chat memory code (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaa006bd authored Sep 11, 2023
1 parent 240429a commit 214c449
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions ai/question.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,21 @@
def _cleanup_chats():
for chat_id in chatMemories.keys():
# drop chats that are older than 5 minutes
if time.time() - chatMemories[chat_id]["lastQuestion"] > 5 * 60:
del chatMemories[chat_id]
if time.time() - chatMemories.get(chat_id,{}).get("lastQuestionTime") > 5 * 60:
chatMemories.pop(chat_id)


def _get_chat(chat_id: str):
def _get_chat_memory(chat_id: str)->ConversationBufferMemory:
_cleanup_chats()

if chat_id not in chatMemories:
chatMemories[chat_id] = {

chatMemories.setdefault(chat_id, {
"memory": ConversationBufferMemory(
memory_key="chat_history", return_messages=True, output_key="answer"
),
"lastQuestion": time.time(),
}

chatMemories[chat_id]["lastQuestion"] = time.time()

return chatMemories[chat_id]["memory"]
"lastQuestionTime": time.time(),
}).set('lastQuestionTime',time.time())

return chatMemories.get(chat_id,{}).get("memory")


def ask(repo_id: int, chat_id: str, question: str):
Expand All @@ -46,27 +43,24 @@ def ask(repo_id: int, chat_id: str, question: str):
db = FAISS.load_local(os.path.join(repo_path, "vector_store"), embeddings)

retriever = db.as_retriever()
end = time.time()

retriever.search_kwargs["distance_metric"] = "cos"
retriever.search_kwargs["fetch_k"] = 100
retriever.search_kwargs["maximal_marginal_relevance"] = True
retriever.search_kwargs["k"] = 20

memory = _get_chat(chat_id)
memory = _get_chat_memory(chat_id)

qa = ConversationalRetrievalChain.from_llm(
llm=ChatOpenAI(temperature=0),
memory=memory,
retriever=retriever,
return_source_documents=True,
)
end = time.time()

result = qa(question)
print(f"Answer: {result['answer']}")
print(f"Sources: {[x.metadata['source'] for x in result['source_documents']]}")
end = time.time()

return result["answer"], [x.metadata["source"] for x in result["source_documents"]]

Expand Down

0 comments on commit 214c449

Please sign in to comment.