Support OSS Data Loader #166

Merged (4 commits) on Aug 27, 2024
24 changes: 24 additions & 0 deletions src/pai_rag/app/api/query.py
@@ -138,9 +138,33 @@ async def upload_data(
task_id=task_id,
input_files=input_files,
filter_pattern=None,
oss_prefix=None,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
)

return {"task_id": task_id}


@router.post("/upload_data_from_oss")
async def upload_oss_data(
oss_prefix: str = None,
faiss_path: str = None,
enable_raptor: bool = False,
background_tasks: BackgroundTasks = BackgroundTasks(),
):
task_id = uuid.uuid4().hex
background_tasks.add_task(
rag_service.add_knowledge_async,
task_id=task_id,
input_files=None,
filter_pattern=None,
oss_prefix=oss_prefix,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
from_oss=True,
)

return {"task_id": task_id}
14 changes: 10 additions & 4 deletions src/pai_rag/app/web/rag_client.py
@@ -1,7 +1,6 @@
import json
from typing import Any
import requests
import html
import httpx
import os
import re
@@ -92,10 +91,13 @@ def _format_rag_response(
elif is_finished:
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
file_url = doc["metadata"].get("file_url", None)
if filename:
formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
if file_url:
formatted_file_name = f"""[{formatted_file_name}]({file_url})"""
referenced_docs += (
f'[{i+1}]: {formatted_file_name} Score:{doc["score"]} \n'
f'[{i+1}]: {formatted_file_name} Score:{doc["score"]}\n'
)

formatted_answer = ""
@@ -185,7 +187,7 @@ def query_vector(self, text: str):
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=text)
else:
for i, doc in enumerate(response["docs"]):
# html_content = markdown.markdown()
file_url = doc.get("metadata", {}).get("file_url", None)
media_url = doc.get("metadata", {}).get("image_url", None)
if media_url and isinstance(media_url, list):
media_url = "<br>".join(
@@ -196,7 +198,11 @@
)
elif media_url:
media_url = f"""<img src="{media_url}"/>"""
safe_html_content = html.escape(doc["text"]).replace("\n", "<br>")
safe_html_content = doc["text"]
if file_url:
safe_html_content = (
f"""<a href="{file_url}">{safe_html_content}</a>"""
)
formatted_text += '<tr style="font-size: 13px;"><td>Doc {}</td><td>{}</td><td>{}</td><td>{}</td></tr>\n'.format(
i + 1, doc["score"], safe_html_content, media_url
)
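As a small worked example of the citation formatting above (the file name and URL are made up): the 32-character hash prefix is stripped from the stored file name, and when a file_url is present the name is rendered as a Markdown link.

```python
# Worked example of the reference formatting in _format_rag_response.
# The file name and URL below are illustrative only.
import re

filename = "0123456789abcdef0123456789abcdef_report.pdf"
file_url = "https://example-bucket.oss-cn-hangzhou.aliyuncs.com/report.pdf"

formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)   # "report.pdf"
formatted_file_name = f"[{formatted_file_name}]({file_url})"
print(f"[1]: {formatted_file_name} Score:0.82")
# [1]: [report.pdf](https://example-bucket.oss-cn-hangzhou.aliyuncs.com/report.pdf) Score:0.82
```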
4 changes: 4 additions & 0 deletions src/pai_rag/config/settings_multi_modal.toml
@@ -25,6 +25,9 @@ host = "Aliyun-Redis host"
password = "Aliyun-Redis user:pwd"
persist_path = "localdata/storage"

[rag.data_loader]
type = "Local" # [Local, Oss]

[rag.data_reader]
type = "SimpleDirectoryReader"
enable_multimodal = true
@@ -82,6 +85,7 @@ chunk_overlap = 10
[rag.oss_store]
bucket = ""
endpoint = ""
prefix = ""

[rag.postprocessor]
reranker_type = "simple-weighted-reranker" # [simple-weighted-reranker, model-based-reranker]
37 changes: 34 additions & 3 deletions src/pai_rag/core/rag_application.py
@@ -11,6 +11,7 @@
import json
import logging
import os
import copy
from uuid import uuid4

DEFAULT_EMPTY_RESPONSE_GEN = "Empty Response"
@@ -60,8 +61,9 @@ async def aload_knowledge(
enable_raptor=False,
):
sessioned_config = self.config
sessioned_config.rag.data_loader.update({"type": "Local"})
if faiss_path:
sessioned_config = self.config.copy()
sessioned_config = copy.copy(self.config)
sessioned_config.rag.index.update({"persist_path": faiss_path})
self.logger.info(
f"Update rag_application config with faiss_persist_path: {faiss_path}"
@@ -74,13 +76,42 @@
input_files, filter_pattern, enable_qa_extraction, enable_raptor
)

async def aload_knowledge_from_oss(
self,
filter_pattern=None,
oss_prefix=None,
faiss_path=None,
enable_qa_extraction=False,
enable_raptor=False,
):
sessioned_config = copy.copy(self.config)
sessioned_config.rag.data_loader.update({"type": "Oss"})
sessioned_config.rag.oss_store.update({"prefix": oss_prefix})
_ = module_registry.get_module_with_config("OssCacheModule", sessioned_config)
self.logger.info(
f"Update rag_application config with data_loader type: Oss and Oss Bucket prefix: {oss_prefix}"
)
data_loader = module_registry.get_module_with_config(
"DataLoaderModule", sessioned_config
)
if faiss_path:
sessioned_config.rag.index.update({"persist_path": faiss_path})
self.logger.info(
f"Update rag_application config with faiss_persist_path: {faiss_path}"
)
await data_loader.aload(
filter_pattern=filter_pattern,
enable_qa_extraction=enable_qa_extraction,
enable_raptor=enable_raptor,
)

async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
return RetrievalResponse(docs=[])

sessioned_config = self.config
if query.vector_db and query.vector_db.faiss_path:
sessioned_config = self.config.copy()
sessioned_config = copy.copy(self.config)
sessioned_config.rag.index.update(
{"persist_path": query.vector_db.faiss_path}
)
@@ -123,7 +154,7 @@ async def aquery_rag(self, query: RagQuery):

sessioned_config = self.config
if query.vector_db and query.vector_db.faiss_path:
sessioned_config = self.config.copy()
sessioned_config = copy.copy(self.config)
sessioned_config.rag.index.update(
{"persist_path": query.vector_db.faiss_path}
)
27 changes: 19 additions & 8 deletions src/pai_rag/core/rag_service.py
@@ -99,23 +99,34 @@ def check_updates(self):
async def add_knowledge_async(
self,
task_id: str,
input_files: List[str],
input_files: List[str] = None,
filter_pattern: str = None,
oss_prefix: str = None,
faiss_path: str = None,
enable_qa_extraction: bool = False,
enable_raptor: bool = False,
from_oss: bool = False,
):
self.check_updates()
with open(TASK_STATUS_FILE, "a") as f:
f.write(f"{task_id}\tprocessing\n")
try:
await self.rag.aload_knowledge(
input_files,
filter_pattern,
faiss_path,
enable_qa_extraction,
enable_raptor,
)
if not from_oss:
await self.rag.aload_knowledge(
input_files,
filter_pattern,
faiss_path,
enable_qa_extraction,
enable_raptor,
)
else:
await self.rag.aload_knowledge_from_oss(
filter_pattern,
oss_prefix,
faiss_path,
enable_qa_extraction,
enable_raptor,
)
with open(TASK_STATUS_FILE, "a") as f:
f.write(f"{task_id}\tcompleted\n")
except Exception as ex:
22 changes: 20 additions & 2 deletions src/pai_rag/data/rag_dataloader.py
@@ -99,6 +99,11 @@ def _get_nodes(
filter_pattern: str,
enable_qa_extraction: bool,
):
tmp_index_doc = self.index.vector_index._docstore.docs
seen_files = set(
[_doc.metadata.get("file_name") for _, _doc in tmp_index_doc.items()]
)

filter_pattern = filter_pattern or "*"
if isinstance(file_path, list):
input_files = [f for f in file_path if os.path.isfile(f)]
@@ -115,9 +120,22 @@
if len(input_files) == 0:
return

data_reader = self.datareader_factory.get_reader(input_files)
# Check whether the file name is already in seen_files; if so, skip the current file.
new_input_files = []
for input_file in input_files:
if os.path.basename(input_file) in seen_files:
print(
f"[RagDataLoader] {os.path.basename(input_file)} already exists, skip it."
)
continue
new_input_files.append(input_file)
if len(new_input_files) == 0:
return
print(f"[RagDataLoader] {len(new_input_files)} files will be loaded.")

data_reader = self.datareader_factory.get_reader(input_files=new_input_files)
docs = data_reader.load_data()
logger.info(f"[DataReader] Loaded {len(docs)} docs.")
print(f"[DataReader] Loaded {len(docs)} docs.")
nodes = []

doc_cnt_map = {}
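The loader change above skips files whose basenames are already present in the docstore. Below is a self-contained sketch of that dedup step, useful for checking the behavior in isolation; filter_new_files is a hypothetical helper, and the SimpleNamespace documents stand in for self.index.vector_index._docstore.docs.

```python
# Standalone sketch of the filename-level dedup added to _get_nodes.
# `filter_new_files` is hypothetical; `docstore_docs` mimics
# self.index.vector_index._docstore.docs (a dict of id -> document).
import os
from types import SimpleNamespace


def filter_new_files(input_files, docstore_docs):
    seen_files = {doc.metadata.get("file_name") for doc in docstore_docs.values()}
    new_files = []
    for path in input_files:
        if os.path.basename(path) in seen_files:
            print(f"[RagDataLoader] {os.path.basename(path)} already exists, skip it.")
            continue
        new_files.append(path)
    return new_files


# report.pdf was indexed earlier, so only notes.md survives the filter.
docstore_docs = {"doc-1": SimpleNamespace(metadata={"file_name": "report.pdf"})}
print(filter_new_files(["/tmp/report.pdf", "/tmp/notes.md"], docstore_docs))
# ['/tmp/notes.md']
```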