-
Notifications
You must be signed in to change notification settings - Fork 0
/
chatomatic.py
100 lines (82 loc) · 3.96 KB
/
chatomatic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Question matching ignores the case of the query: answer() lowercases the incoming question before lookup.
from qa_database import *
import random
from rank_bm25 import *
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModel
import torch
# Sentence-embedding model used for the "transformers" similarity method.
# It is created once, at import time, and shared as a module-level global.
_SENT_SIM_MODEL_NAME = 'sentence-transformers/msmarco-distilbert-base-tas-b'
sent_sim_model = SentenceTransformer(_SENT_SIM_MODEL_NAME)
class Chatomatic:
    """Answer questions from per-language question/answer databases.

    Each language maps to a ``QADatabase`` whose ``questions`` entries expose
    a ``title`` (the question text) and ``answers`` (candidate replies).
    A question is answered by exact title match first; otherwise the most
    similar stored question is found with BM25 or sentence embeddings and
    one of its answers is returned.

    Attributes:
        random_generator: RNG used to pick one of several candidate answers.
        qa_databases: ``{language: QADatabase}`` owned by this instance.
        cache: ``{language: {(method, question): matched question}}``
            memoising similarity searches for this instance.
    """

    # Stateless helper, safe to share between instances.
    random_generator = random.Random()

    def __init__(self, file_path, language="en"):
        """Create a bot and load its first dataset.

        Args:
            file_path: Path to a ``.yml``/``.yaml`` or ``.json`` dataset.
            language: Language key the dataset is stored under.

        Raises:
            ValueError: If ``file_path`` has an unsupported extension.
        """
        # Kept per instance: as class attributes these dicts were shared, so
        # creating a second bot replaced the first bot's database.
        self.cache = {}
        self.qa_databases = {language: QADatabase()}
        self.load_from_dataset(file_path, self.qa_databases[language])

    def load_from_dataset(self, file_path, qa_database):
        """Load ``file_path`` into ``qa_database`` based on its extension.

        Args:
            file_path: Path ending in ``.yml``, ``.yaml`` or ``.json``.
            qa_database: Object providing ``load_from_yaml`` and
                ``load_from_json``.

        Raises:
            ValueError: If the extension is not one of the supported ones
                (previously such files were silently ignored).
        """
        if file_path.endswith(('.yml', '.yaml')):
            qa_database.load_from_yaml(file_path)
        elif file_path.endswith('.json'):
            qa_database.load_from_json(file_path)
        else:
            raise ValueError(
                "Unsupported dataset file type: {!r} "
                "(expected .yml, .yaml or .json)".format(file_path)
            )

    def add_dataset(self, file_path, language="en"):
        """Load an additional dataset into the database for ``language``.

        A new database is created when the language has none yet.

        Raises:
            ValueError: If ``file_path`` has an unsupported extension.
        """
        if language not in self.qa_databases:
            self.qa_databases[language] = QADatabase()
        self.load_from_dataset(file_path, self.qa_databases[language])
        # Newly loaded questions can change which stored question is the
        # closest match, so memoised results for this language are stale.
        self.cache.pop(language, None)

    def find_answer_to_question(self, question, language="en"):
        """Return a random answer of the question whose title equals ``question``.

        The comparison is an exact string match; no case folding is done here.

        Returns:
            One of the matching question's answers, or ``None`` when no
            stored title matches.

        Raises:
            KeyError: If no database exists for ``language``.
        """
        qa_database = self.qa_databases[language]
        for qa in qa_database.questions:
            if qa.title == question:
                return self.random_generator.choice(qa.answers)
        return None

    def find_most_similar_question_transformers(self, question, language="en"):
        """Return the stored question closest to ``question`` by cosine similarity.

        Titles and the query are embedded with the module-level
        ``sent_sim_model``; the corpus is re-encoded on every call.

        Returns:
            The best-scoring entry of ``questions``, or ``None`` when the
            database holds no questions.
        """
        qa_database = self.qa_databases[language]
        corpus = [doc.title for doc in qa_database.questions]
        if not corpus:
            return None
        corpus_embeddings = sent_sim_model.encode(corpus, convert_to_tensor=True)
        sentence_embedding = sent_sim_model.encode(question, convert_to_tensor=True)
        cos_scores = util.pytorch_cos_sim(sentence_embedding, corpus_embeddings)[0]
        # Tensor.argmax avoids the implicit tensor -> NumPy conversion that
        # np.argpartition relied on.
        top_result = int(cos_scores.argmax())
        return qa_database.questions[top_result]

    def find_most_similar_question_bm25(self, question, language="en"):
        """Return the stored question with the highest BM25 score for ``question``.

        Titles and the query are tokenised by splitting on single spaces.

        Returns:
            The best-scoring entry of ``questions`` (the first one on ties),
            or ``None`` when the database holds no questions.
        """
        qa_database = self.qa_databases[language]
        tokenized_corpus = [doc.title.split(" ") for doc in qa_database.questions]
        if not tokenized_corpus:
            return None
        bm25 = BM25Okapi(tokenized_corpus)
        doc_scores = bm25.get_scores(question.split(" "))
        best_index = int(np.argmax(doc_scores))
        return qa_database.questions[best_index]

    def find_most_similar_question(self, question, language="en", method="transformers"):
        """Return the stored question most similar to ``question``, memoised.

        BM25 approach from
        https://www.analyticsvidhya.com/blog/2021/05/build-your-own-nlp-based-search-engine-using-bm25/

        Args:
            question: Query text.
            language: Language key of the database to search.
            method: ``"bm25"`` or ``"transformers"``.

        Returns:
            The matched question object, or ``None`` when the database is
            empty.

        Raises:
            ValueError: If ``method`` is not a supported name.
            KeyError: If no database exists for ``language``.
        """
        if method == "bm25":
            finder = self.find_most_similar_question_bm25
        elif method == "transformers":
            finder = self.find_most_similar_question_transformers
        else:
            raise ValueError(
                "Unknown similarity method: {!r} "
                "(expected 'bm25' or 'transformers')".format(method)
            )

        language_cache = self.cache.setdefault(language, {})
        # The method is part of the key: the two methods can disagree on the
        # best match, so their results must not be served for one another.
        cache_key = (method, question)
        if cache_key in language_cache:
            return language_cache[cache_key]

        result = finder(question, language)
        if result is not None:
            language_cache[cache_key] = result
        return result

    def answer(self, question, language="en", method_for_similarity="bm25"):
        """Answer ``question``, falling back to the most similar known question.

        The question is lowercased before any lookup.

        Args:
            question: Query text.
            language: Language key of the database to use.
            method_for_similarity: ``"bm25"`` or ``"transformers"``; only
                used when there is no exact title match.

        Returns:
            An answer chosen at random from the matched question's answers,
            or ``None`` when the database holds no questions.

        Raises:
            ValueError: If a similarity search is needed and
                ``method_for_similarity`` is not supported.
            KeyError: If no database exists for ``language``.
        """
        question = question.lower()
        answer = self.find_answer_to_question(question, language)
        if answer is None:
            closest = self.find_most_similar_question(question, language, method_for_similarity)
            if closest is None:
                return None
            answer = self.random_generator.choice(closest.answers)  # TODO: GAN
        return answer
# Demo / timing run. It executes at import time and leaves ``chatomatic``
# defined as a module-level name.
import time

chatomatic = Chatomatic("test.yml")

# Measure how much time it takes to answer each question. The last query is
# repeated; presumably to show the effect of the similarity cache on a
# second identical lookup.
for _question in ("great, thanks", "nice, tnxx", "nice, tnxx"):
    # perf_counter is monotonic and high-resolution, unlike the wall clock
    # returned by time.time(), so it is the right timer for short intervals.
    start = time.perf_counter()
    print(chatomatic.answer(_question, method_for_similarity="transformers"))
    end = time.perf_counter()
    print(end - start)