forked from NirDiamant/RAG_Techniques
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fusion_retrieval.py
143 lines (113 loc) · 5.45 KB
/
fusion_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import sys
from dotenv import load_dotenv
from langchain.docstore.document import Document
from typing import List
from rank_bm25 import BM25Okapi
import numpy as np
# Add the parent directory to the path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from helper_functions import *
from evaluation.evalute_rag import *
# Load environment variables from a .env file, if one is found.
load_dotenv()
# Re-export the OpenAI key only when it is actually set: os.environ values must
# be str, so assigning os.getenv's None (key missing) would raise TypeError at
# import time and even break `--help`.
_openai_api_key = os.getenv('OPENAI_API_KEY')
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key
# Function to encode the PDF to a vector store and return split documents
def encode_pdf_and_get_split_documents(path, chunk_size=1000, chunk_overlap=200):
    """
    Load a PDF, split it into overlapping chunks, and index the chunks in FAISS.

    Args:
        path: The path to the PDF file.
        chunk_size: The desired size of each text chunk, measured in characters
            (the splitter uses ``len`` as its length function).
        chunk_overlap: The amount of overlap between consecutive chunks, in characters.

    Returns:
        tuple: ``(vectorstore, cleaned_texts)`` where ``vectorstore`` is the FAISS
        vector store built with OpenAI embeddings and ``cleaned_texts`` is the list
        of cleaned chunk ``Document`` objects that were indexed, in the order they
        were passed to ``FAISS.from_documents``.
    """
    loader = PyPDFLoader(path)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    texts = text_splitter.split_documents(documents)
    # replace_t_with_space comes from helper_functions (star-imported); presumably
    # it replaces tab characters in each chunk's text — confirm in that module.
    cleaned_texts = replace_t_with_space(texts)
    embeddings = OpenAIEmbeddings()
    vectorstore = FAISS.from_documents(cleaned_texts, embeddings)
    return vectorstore, cleaned_texts
# Function to create BM25 index for keyword retrieval
def create_bm25_index(documents: List[Document]) -> BM25Okapi:
    """
    Create a BM25 index from the given documents.

    Each document is tokenized by whitespace-splitting its ``page_content``;
    queries must be tokenized the same way for scores to be meaningful
    (``fusion_retrieval`` uses ``query.split()``). Corpus position ``i`` in the
    returned index corresponds to ``documents[i]``.

    Args:
        documents (List[Document]): List of documents to index.

    Returns:
        BM25Okapi: An index that can be used for BM25 scoring.

    Raises:
        ValueError: If ``documents`` is empty.
    """
    if not documents:
        # BM25Okapi averages document length over the corpus size, so an empty
        # corpus would otherwise fail deep inside it with ZeroDivisionError.
        raise ValueError("Cannot build a BM25 index from an empty list of documents.")
    tokenized_docs = [doc.page_content.split() for doc in documents]
    return BM25Okapi(tokenized_docs)
# Function for fusion retrieval combining keyword-based (BM25) and vector-based search
def _min_max_normalize(scores) -> np.ndarray:
    """
    Linearly rescale ``scores`` onto [0, 1].

    A constant (or empty) array carries no ranking information, so it maps to all
    zeros instead of dividing by a zero range and producing NaNs.
    """
    values = np.asarray(scores, dtype=float)
    if values.size == 0:
        return values
    low = values.min()
    span = values.max() - low
    if span == 0:
        return np.zeros_like(values)
    return (values - low) / span


def fusion_retrieval(vectorstore, bm25, query: str, k: int = 5, alpha: float = 0.5) -> List[Document]:
    """
    Perform fusion retrieval combining keyword-based (BM25) and vector-based search.

    Every document in ``vectorstore`` is scored as
    ``alpha * vector_score + (1 - alpha) * bm25_score``, where both component
    scores are min-max normalized to [0, 1]. Vector distances are inverted so that
    closer documents score higher. Ties keep FAISS index order.

    ``bm25`` must have been built from the same documents, in the same order, as
    were added to ``vectorstore`` (as ``FusionRetrievalRAG`` does), because BM25
    scores are matched to documents by corpus position.

    Args:
        vectorstore (VectorStore): The FAISS vectorstore containing the documents.
        bm25 (BM25Okapi): Pre-computed BM25 index over the same documents.
        query (str): The query string.
        k (int): The number of documents to retrieve.
        alpha (float): The weight for vector search scores (1-alpha will be the weight for BM25 scores).

    Returns:
        List[Document]: The top k documents based on the combined scores (fewer if
        the store holds fewer than k; empty if the store is empty or k <= 0).

    Raises:
        ValueError: If the BM25 corpus size differs from the number of vectors in the store.
    """
    total = vectorstore.index.ntotal
    if total == 0 or k <= 0:
        return []

    # Fetch every document in FAISS index order. similarity_search("") must not
    # be used here: it ranks by similarity to the empty string, which matches
    # neither the BM25 corpus order nor the query ranking.
    all_docs = [
        vectorstore.docstore.search(vectorstore.index_to_docstore_id[position])
        for position in range(total)
    ]

    # get_scores yields one score per corpus document, in corpus order.
    bm25_scores = np.asarray(bm25.get_scores(query.split()), dtype=float)
    if bm25_scores.shape[0] != total:
        raise ValueError(
            f"BM25 index has {bm25_scores.shape[0]} documents but the vector store has {total}; "
            "both must be built from the same document list."
        )

    # similarity_search_with_score returns (doc, distance) pairs sorted by
    # distance, not in index order, so realign them by text. Chunks with identical
    # text get identical BM25 tokens and (up to embedding noise) the same vector,
    # so keying by text is safe; setdefault keeps the first, i.e. closest, hit.
    distance_by_text = {}
    for doc, distance in vectorstore.similarity_search_with_score(query, k=total):
        distance_by_text.setdefault(doc.page_content, float(distance))
    worst_distance = max(distance_by_text.values(), default=0.0)
    distances = np.array(
        [distance_by_text.get(doc.page_content, worst_distance) for doc in all_docs],
        dtype=float,
    )

    # Lower distance means more similar, so invert after normalizing.
    vector_scores = 1.0 - _min_max_normalize(distances)
    combined_scores = alpha * vector_scores + (1 - alpha) * _min_max_normalize(bm25_scores)

    # Stable sort on negated scores: highest first, ties keep index order.
    ranked = np.argsort(-combined_scores, kind="stable")
    return [all_docs[i] for i in ranked[:k]]
class FusionRetrievalRAG:
    """Fusion (BM25 + vector) retriever over the chunks of a single PDF file."""

    def __init__(self, path: str, chunk_size: int = 1000, chunk_overlap: int = 200):
        """
        Initializes the FusionRetrievalRAG class by setting up the vector store and BM25 index.

        Both indexes are built from the same list of cleaned chunks, in the same
        order, which is what lets ``fusion_retrieval`` match BM25 scores to
        vector-store documents by position.

        Args:
            path (str): Path to the PDF file.
            chunk_size (int): The size of each text chunk.
            chunk_overlap (int): The overlap between consecutive chunks.
        """
        self.vectorstore, self.cleaned_texts = encode_pdf_and_get_split_documents(path, chunk_size, chunk_overlap)
        self.bm25 = create_bm25_index(self.cleaned_texts)

    def run(self, query: str, k: int = 5, alpha: float = 0.5):
        """
        Executes the fusion retrieval for the given query, displays the hits, and returns them.

        Args:
            query (str): The search query.
            k (int): The number of documents to retrieve.
            alpha (float): The weight of vector search vs. BM25 search.

        Returns:
            List[Document]: The top k retrieved documents.
        """
        top_docs = fusion_retrieval(self.vectorstore, self.bm25, query, k, alpha)
        docs_content = [doc.page_content for doc in top_docs]
        show_context(docs_content)
        return top_docs
def parse_args(argv=None):
    """
    Parses command-line arguments for the fusion retrieval script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads ``sys.argv[1:]`` exactly as before; passing a
            list allows tests or other code to drive the CLI without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed arguments (path, chunk_size, chunk_overlap,
        query, k, alpha).
    """
    import argparse
    parser = argparse.ArgumentParser(description="Fusion Retrieval RAG Script")
    parser.add_argument('--path', type=str, default="../data/Understanding_Climate_Change.pdf",
                        help='Path to the PDF file.')
    parser.add_argument('--chunk_size', type=int, default=1000, help='Size of each chunk.')
    parser.add_argument('--chunk_overlap', type=int, default=200, help='Overlap between consecutive chunks.')
    parser.add_argument('--query', type=str, default='What are the impacts of climate change on the environment?',
                        help='Query to retrieve documents.')
    parser.add_argument('--k', type=int, default=5, help='Number of documents to retrieve.')
    parser.add_argument('--alpha', type=float, default=0.5, help='Weight for vector search vs. BM25.')
    return parser.parse_args(argv)
if __name__ == "__main__":
args = parse_args()
retriever = FusionRetrievalRAG(path=args.path, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
retriever.run(query=args.query, k=args.k, alpha=args.alpha)