diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py
index 3677f3fe..d32c9dc0 100644
--- a/src/pai_rag/app/api/query.py
+++ b/src/pai_rag/app/api/query.py
@@ -138,9 +138,33 @@ async def upload_data(
task_id=task_id,
input_files=input_files,
filter_pattern=None,
+ oss_prefix=None,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
)
return {"task_id": task_id}
+
+
+@router.post("/upload_data_from_oss")
+async def upload_oss_data(
+ oss_prefix: str = None,
+ faiss_path: str = None,
+ enable_raptor: bool = False,
+ background_tasks: BackgroundTasks = BackgroundTasks(),
+):
+ task_id = uuid.uuid4().hex
+ background_tasks.add_task(
+ rag_service.add_knowledge_async,
+ task_id=task_id,
+ input_files=None,
+ filter_pattern=None,
+ oss_prefix=oss_prefix,
+ faiss_path=faiss_path,
+ enable_qa_extraction=False,
+ enable_raptor=enable_raptor,
+ from_oss=True,
+ )
+
+ return {"task_id": task_id}
diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index 0f0a3879..ab4216f0 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -1,7 +1,6 @@
import json
from typing import Any
import requests
-import html
import httpx
import os
import re
@@ -92,10 +91,13 @@ def _format_rag_response(
elif is_finished:
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
+ file_url = doc["metadata"].get("file_url", None)
if filename:
formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
+ if file_url:
+ formatted_file_name = f"""[{formatted_file_name}]({file_url})"""
referenced_docs += (
- f'[{i+1}]: {formatted_file_name} Score:{doc["score"]} \n'
+ f'[{i+1}]: {formatted_file_name} Score:{doc["score"]}\n'
)
formatted_answer = ""
@@ -185,7 +187,7 @@ def query_vector(self, text: str):
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=text)
else:
for i, doc in enumerate(response["docs"]):
- # html_content = markdown.markdown()
+ file_url = doc.get("metadata", {}).get("file_url", None)
media_url = doc.get("metadata", {}).get("image_url", None)
if media_url and isinstance(media_url, list):
media_url = "
".join(
@@ -196,7 +198,11 @@ def query_vector(self, text: str):
)
elif media_url:
media_url = f""""""
- safe_html_content = html.escape(doc["text"]).replace("\n", "
")
+ safe_html_content = doc["text"]
+ if file_url:
+ safe_html_content = (
+                    f"""<a href="{file_url}">{safe_html_content}</a>"""
+ )
            formatted_text += '<tr><td>Doc {}</td><td>{}</td><td>{}</td><td>{}</td></tr>\n'.format(
i + 1, doc["score"], safe_html_content, media_url
)
diff --git a/src/pai_rag/config/settings_multi_modal.toml b/src/pai_rag/config/settings_multi_modal.toml
index b6407129..601a9ca0 100644
--- a/src/pai_rag/config/settings_multi_modal.toml
+++ b/src/pai_rag/config/settings_multi_modal.toml
@@ -25,6 +25,9 @@ host = "Aliyun-Redis host"
password = "Aliyun-Redis user:pwd"
persist_path = "localdata/storage"
+[rag.data_loader]
+type = "Local" # [Local, Oss]
+
[rag.data_reader]
type = "SimpleDirectoryReader"
enable_multimodal = true
@@ -82,6 +85,7 @@ chunk_overlap = 10
[rag.oss_store]
bucket = ""
endpoint = ""
+prefix = ""
[rag.postprocessor]
reranker_type = "simple-weighted-reranker" # [simple-weighted-reranker, model-based-reranker]
diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index be4d69c2..0c978c4f 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -11,6 +11,7 @@
import json
import logging
import os
+import copy
from uuid import uuid4
DEFAULT_EMPTY_RESPONSE_GEN = "Empty Response"
@@ -60,8 +61,9 @@ async def aload_knowledge(
enable_raptor=False,
):
sessioned_config = self.config
+ sessioned_config.rag.data_loader.update({"type": "Local"})
if faiss_path:
- sessioned_config = self.config.copy()
+ sessioned_config = copy.copy(self.config)
sessioned_config.rag.index.update({"persist_path": faiss_path})
self.logger.info(
f"Update rag_application config with faiss_persist_path: {faiss_path}"
@@ -74,13 +76,42 @@ async def aload_knowledge(
input_files, filter_pattern, enable_qa_extraction, enable_raptor
)
+ async def aload_knowledge_from_oss(
+ self,
+ filter_pattern=None,
+ oss_prefix=None,
+ faiss_path=None,
+ enable_qa_extraction=False,
+ enable_raptor=False,
+ ):
+ sessioned_config = copy.copy(self.config)
+ sessioned_config.rag.data_loader.update({"type": "Oss"})
+ sessioned_config.rag.oss_store.update({"prefix": oss_prefix})
+ _ = module_registry.get_module_with_config("OssCacheModule", sessioned_config)
+ self.logger.info(
+ f"Update rag_application config with data_loader type: Oss and Oss Bucket prefix: {oss_prefix}"
+ )
+ data_loader = module_registry.get_module_with_config(
+ "DataLoaderModule", sessioned_config
+ )
+ if faiss_path:
+ sessioned_config.rag.index.update({"persist_path": faiss_path})
+ self.logger.info(
+ f"Update rag_application config with faiss_persist_path: {faiss_path}"
+ )
+ await data_loader.aload(
+ filter_pattern=filter_pattern,
+ enable_qa_extraction=enable_qa_extraction,
+ enable_raptor=enable_raptor,
+ )
+
async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
return RetrievalResponse(docs=[])
sessioned_config = self.config
if query.vector_db and query.vector_db.faiss_path:
- sessioned_config = self.config.copy()
+ sessioned_config = copy.copy(self.config)
sessioned_config.rag.index.update(
{"persist_path": query.vector_db.faiss_path}
)
@@ -123,7 +154,7 @@ async def aquery_rag(self, query: RagQuery):
sessioned_config = self.config
if query.vector_db and query.vector_db.faiss_path:
- sessioned_config = self.config.copy()
+ sessioned_config = copy.copy(self.config)
sessioned_config.rag.index.update(
{"persist_path": query.vector_db.faiss_path}
)
diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py
index 4412ee1b..75fe9344 100644
--- a/src/pai_rag/core/rag_service.py
+++ b/src/pai_rag/core/rag_service.py
@@ -99,23 +99,34 @@ def check_updates(self):
async def add_knowledge_async(
self,
task_id: str,
- input_files: List[str],
+ input_files: List[str] = None,
filter_pattern: str = None,
+ oss_prefix: str = None,
faiss_path: str = None,
enable_qa_extraction: bool = False,
enable_raptor: bool = False,
+ from_oss: bool = False,
):
self.check_updates()
with open(TASK_STATUS_FILE, "a") as f:
f.write(f"{task_id}\tprocessing\n")
try:
- await self.rag.aload_knowledge(
- input_files,
- filter_pattern,
- faiss_path,
- enable_qa_extraction,
- enable_raptor,
- )
+ if not from_oss:
+ await self.rag.aload_knowledge(
+ input_files,
+ filter_pattern,
+ faiss_path,
+ enable_qa_extraction,
+ enable_raptor,
+ )
+ else:
+ await self.rag.aload_knowledge_from_oss(
+ filter_pattern,
+ oss_prefix,
+ faiss_path,
+ enable_qa_extraction,
+ enable_raptor,
+ )
with open(TASK_STATUS_FILE, "a") as f:
f.write(f"{task_id}\tcompleted\n")
except Exception as ex:
diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py
index 38bca41e..07b3eeb4 100644
--- a/src/pai_rag/data/rag_dataloader.py
+++ b/src/pai_rag/data/rag_dataloader.py
@@ -99,6 +99,11 @@ def _get_nodes(
filter_pattern: str,
enable_qa_extraction: bool,
):
+ tmp_index_doc = self.index.vector_index._docstore.docs
+ seen_files = set(
+ [_doc.metadata.get("file_name") for _, _doc in tmp_index_doc.items()]
+ )
+
filter_pattern = filter_pattern or "*"
if isinstance(file_path, list):
input_files = [f for f in file_path if os.path.isfile(f)]
@@ -115,9 +120,22 @@ def _get_nodes(
if len(input_files) == 0:
return
- data_reader = self.datareader_factory.get_reader(input_files)
+    # Skip any file whose name is already in seen_files.
+ new_input_files = []
+ for input_file in input_files:
+ if os.path.basename(input_file) in seen_files:
+ print(
+ f"[RagDataLoader] {os.path.basename(input_file)} already exists, skip it."
+ )
+ continue
+ new_input_files.append(input_file)
+ if len(new_input_files) == 0:
+ return
+ print(f"[RagDataLoader] {len(new_input_files)} files will be loaded.")
+
+ data_reader = self.datareader_factory.get_reader(input_files=new_input_files)
docs = data_reader.load_data()
- logger.info(f"[DataReader] Loaded {len(docs)} docs.")
+ print(f"[DataReader] Loaded {len(docs)} docs.")
nodes = []
doc_cnt_map = {}
diff --git a/src/pai_rag/data/rag_oss_dataloader.py b/src/pai_rag/data/rag_oss_dataloader.py
new file mode 100644
index 00000000..4f8c3457
--- /dev/null
+++ b/src/pai_rag/data/rag_oss_dataloader.py
@@ -0,0 +1,371 @@
+import datetime
+import json
+import os
+from typing import Any, Dict, List
+from fastapi.concurrency import run_in_threadpool
+from llama_index.core import Settings
+from llama_index.core.schema import TextNode, ImageNode, ImageDocument
+from llama_index.llms.huggingface import HuggingFaceLLM
+
+from pai_rag.integrations.nodeparsers.base import MarkdownNodeParser
+from pai_rag.integrations.extractors.html_qa_extractor import HtmlQAExtractor
+from pai_rag.integrations.extractors.text_qa_extractor import TextQAExtractor
+from pai_rag.modules.nodeparser.node_parser import node_id_hash
+from pai_rag.data.open_dataset import MiraclOpenDataSet, DuRetrievalDataSet
+
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_LOCAL_QA_MODEL_PATH = "./model_repository/qwen_1.8b"
+DOC_TYPES_DO_NOT_NEED_CHUNKING = set(
+ [".csv", ".xlsx", ".xls", ".htm", ".html", ".jsonl"]
+)
+IMAGE_FILE_TYPES = set([".jpg", ".jpeg", ".png"])
+
+IMAGE_URL_REGEX = re.compile(
+ r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+\.(?:jpg|jpeg|png)",
+ re.IGNORECASE,
+)
+
+
+class OssDataLoader:
+ """
+ OssDataLoader:
+ Load data with corresponding data readers according to config.
+ """
+
+ def __init__(
+ self,
+ datareader_factory,
+ node_parser,
+ index,
+ bm25_index,
+ oss_cache,
+ node_enhance,
+ use_local_qa_model=False,
+ ):
+ self.datareader_factory = datareader_factory
+ self.node_parser = node_parser
+ self.oss_cache = oss_cache
+ self.index = index
+ self.bm25_index = bm25_index
+ self.node_enhance = node_enhance
+
+ if use_local_qa_model:
+            # This option is not supported via the API yet.
+ self.qa_llm = HuggingFaceLLM(
+ model_name=DEFAULT_LOCAL_QA_MODEL_PATH,
+ tokenizer_name=DEFAULT_LOCAL_QA_MODEL_PATH,
+ )
+ else:
+ self.qa_llm = Settings.llm
+
+ html_extractor = HtmlQAExtractor(llm=self.qa_llm)
+ txt_extractor = TextQAExtractor(llm=self.qa_llm)
+
+ self.extractors = [html_extractor, txt_extractor]
+
+ logger.info("OssDataLoader initialized.")
+
+ def _extract_file_type(self, metadata: Dict[str, Any]):
+ file_name = metadata.get("file_name", "dummy.txt")
+ return os.path.splitext(file_name)[1]
+
+ def _get_oss_files(self):
+ files = []
+ if self.oss_cache:
+ object_list = self.oss_cache.list_objects()
+ oss_file_path_dir = "localdata/oss_tmp"
+ if not os.path.exists(oss_file_path_dir):
+ os.makedirs(oss_file_path_dir)
+ for oss_obj in object_list:
+ if oss_obj.key[:-1] != self.oss_cache.prefix:
+                    # Default to False so a failed ACL update just skips this object.
+                    set_public = False
+                    try:
+ set_public = self.oss_cache.put_object_acl(
+ oss_obj.key, "public-read"
+ )
+ except Exception:
+ logger.error(f"Failed to set_public document {oss_obj.key}")
+ if set_public:
+ save_filename = os.path.join(oss_file_path_dir, oss_obj.key)
+ self.oss_cache.get_object_to_file(
+ key=oss_obj.key, filename=save_filename
+ )
+ files.append(save_filename)
+ else:
+ logger.error(f"Failed to load document {oss_obj.key}")
+ return files
+
+ def _get_nodes(
+ self,
+ file_path: str | List[str],
+ filter_pattern: str,
+ enable_qa_extraction: bool,
+ ):
+ tmp_index_doc = self.index.vector_index._docstore.docs
+ seen_files = set(
+ [_doc.metadata.get("file_name") for _, _doc in tmp_index_doc.items()]
+ )
+ filter_pattern = filter_pattern or "*"
+ if isinstance(file_path, list):
+ input_files = [f for f in file_path if os.path.isfile(f)]
+ elif isinstance(file_path, str) and os.path.isdir(file_path):
+ import pathlib
+
+ directory = pathlib.Path(file_path)
+ input_files = [
+ f for f in directory.rglob(filter_pattern) if os.path.isfile(f)
+ ]
+ else:
+ input_files = [file_path]
+
+ if len(input_files) == 0:
+ return
+
+        # Skip any file whose name is already in seen_files.
+ new_input_files = []
+ for input_file in input_files:
+ if os.path.basename(input_file) in seen_files:
+ print(
+ f"[RagOssDataLoader] {os.path.basename(input_file)} already exists, skip it."
+ )
+ continue
+ new_input_files.append(input_file)
+ if len(new_input_files) == 0:
+ return
+ print(f"[RagOssDataLoader] {len(new_input_files)} files will be loaded.")
+
+ data_reader = self.datareader_factory.get_reader(new_input_files)
+ docs = data_reader.load_data()
+ logger.info(f"[DataReader] Loaded {len(docs)} docs.")
+ nodes = []
+
+ doc_cnt_map = {}
+ for doc in docs:
+ doc_type = self._extract_file_type(doc.metadata)
+ doc.metadata["file_path"] = os.path.basename(doc.metadata["file_path"])
+ doc.metadata["file_url"] = self.oss_cache.get_obj_key_url(
+ doc.metadata["file_path"]
+ )
+ doc_key = f"""{doc.metadata.get("file_path", "dummy")}"""
+ if doc_key not in doc_cnt_map:
+ doc_cnt_map[doc_key] = 0
+
+ if isinstance(doc, ImageDocument):
+ node_id = node_id_hash(doc_cnt_map[doc_key], doc)
+ doc_cnt_map[doc_key] += 1
+ nodes.append(
+ ImageNode(
+ id_=node_id, image_url=doc.image_url, metadata=doc.metadata
+ )
+ )
+ elif doc_type in DOC_TYPES_DO_NOT_NEED_CHUNKING:
+ doc_key = f"""{doc.metadata.get("file_path", "dummy")}"""
+ doc_cnt_map[doc_key] += 1
+ node_id = node_id_hash(doc_cnt_map[doc_key], doc)
+ nodes.append(
+ TextNode(id_=node_id, text=doc.text, metadata=doc.metadata)
+ )
+ elif doc_type == ".md" or doc_type == ".pdf":
+ md_node_parser = MarkdownNodeParser(id_func=node_id_hash)
+ tmp_nodes = md_node_parser.get_nodes_from_documents([doc])
+ for node in tmp_nodes:
+ node.id_ = node_id_hash(doc_cnt_map[doc_key], doc)
+ doc_cnt_map[doc_key] += 1
+ nodes.append(node)
+ else:
+ nodes.extend(self.node_parser.get_nodes_from_documents([doc]))
+
+ for node in nodes:
+ node.excluded_embed_metadata_keys.append("file_path")
+ node.excluded_embed_metadata_keys.append("image_url")
+ node.excluded_embed_metadata_keys.append("total_pages")
+ node.excluded_embed_metadata_keys.append("source")
+
+ logger.info(f"[DataReader] Split into {len(nodes)} nodes.")
+
+ # QA metadata extraction
+ if enable_qa_extraction:
+ qa_nodes = []
+
+ for extractor in self.extractors:
+ metadata_list = extractor.extract(nodes)
+ for i, node in enumerate(nodes):
+ qa_extraction_result = metadata_list[i].get(
+ "qa_extraction_result", {}
+ )
+ q_cnt = 0
+ metadata = node.metadata
+ for q, a in qa_extraction_result.items():
+ metadata["answer"] = a
+ qa_nodes.append(
+ TextNode(
+ id_=f"{node.id_}_{q_cnt}", text=q, metadata=metadata
+ )
+ )
+ q_cnt += 1
+ for node in qa_nodes:
+ node.excluded_embed_metadata_keys.append("answer")
+ node.excluded_llm_metadata_keys.append("question")
+ nodes.extend(qa_nodes)
+
+ return nodes
+
+ def load(
+ self,
+ filter_pattern: str,
+ enable_qa_extraction: bool,
+ enable_raptor: bool,
+ ):
+ file_path = self._get_oss_files()
+ nodes = self._get_nodes(file_path, filter_pattern, enable_qa_extraction)
+
+ if not nodes:
+ logger.warning("[DataReader] no nodes parsed.")
+ return
+
+ logger.info("[DataReader] Start inserting to index.")
+
+ if enable_raptor:
+ nodes_with_embeddings = self.node_enhance(nodes=nodes)
+ self.index.vector_index.insert_nodes(nodes_with_embeddings)
+
+ logger.info(
+ f"Inserted {len(nodes)} and enhanced {len(nodes_with_embeddings)-len(nodes)} nodes successfully."
+ )
+ else:
+ self.index.vector_index.insert_nodes(nodes)
+ logger.info(f"Inserted {len(nodes)} nodes successfully.")
+
+ self.index.vector_index.storage_context.persist(
+ persist_dir=self.index.persist_path
+ )
+
+ index_metadata_file = os.path.join(self.index.persist_path, "index.metadata")
+ if self.bm25_index:
+ self.bm25_index.add_docs(nodes)
+ metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"})
+ with open(index_metadata_file, "w") as wf:
+ wf.write(metadata_str)
+
+ return
+
+ async def aload(
+ self,
+ filter_pattern: str,
+ enable_qa_extraction: bool,
+ enable_raptor: bool,
+ ):
+ file_path = self._get_oss_files()
+ nodes = await run_in_threadpool(
+ lambda: self._get_nodes(file_path, filter_pattern, enable_qa_extraction)
+ )
+ if not nodes:
+ logger.info("[DataReader] could not find files")
+ return
+
+ logger.info("[DataReader] Start inserting to index.")
+
+ if enable_raptor:
+ nodes_with_embeddings = await self.node_enhance.acall(nodes=nodes)
+ self.index.vector_index.insert_nodes(nodes_with_embeddings)
+
+ logger.info(
+ f"Async inserted {len(nodes)} and enhanced {len(nodes_with_embeddings)-len(nodes)} nodes successfully."
+ )
+
+ else:
+ self.index.vector_index.insert_nodes(nodes)
+ logger.info(f"Inserted {len(nodes)} nodes successfully.")
+
+ self.index.vector_index.storage_context.persist(
+ persist_dir=self.index.persist_path
+ )
+
+ index_metadata_file = os.path.join(self.index.persist_path, "index.metadata")
+ if self.bm25_index:
+ await run_in_threadpool(lambda: self.bm25_index.add_docs(nodes))
+ metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"})
+ with open(index_metadata_file, "w") as wf:
+ wf.write(metadata_str)
+
+ return
+
+ def load_eval_data(self, name: str):
+ logger.info("[DataReader-Evaluation Dataset]")
+ if name == "miracl":
+ miracl_dataset = MiraclOpenDataSet()
+ miracl_nodes, _ = miracl_dataset.load_related_corpus()
+ nodes = []
+ for node in miracl_nodes:
+ node_metadata = {
+ "title": node[2],
+ "file_path": node[3],
+ "file_name": node[3],
+ }
+ nodes.append(
+ TextNode(id_=node[0], text=node[1], metadata=node_metadata)
+ )
+
+ print(f"[DataReader-Evaluation Dataset] Split into {len(nodes)} nodes.")
+
+ print("[DataReader-Evaluation Dataset] Start inserting to index.")
+
+ self.index.vector_index.insert_nodes(nodes)
+ self.index.vector_index.storage_context.persist(
+ persist_dir=self.index.persist_path
+ )
+
+ index_metadata_file = os.path.join(
+ self.index.persist_path, "index.metadata"
+ )
+ if self.bm25_index:
+ self.bm25_index.add_docs(nodes)
+ metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"})
+ with open(index_metadata_file, "w") as wf:
+ wf.write(metadata_str)
+
+ print(
+ f"[DataReader-Evaluation Dataset] Inserted {len(nodes)} nodes successfully."
+ )
+ return
+ elif name == "duretrieval":
+ duretrieval_dataset = DuRetrievalDataSet()
+ miracl_nodes, _, _ = duretrieval_dataset.load_related_corpus()
+ nodes = []
+ for node in miracl_nodes:
+ node_metadata = {
+ "file_path": node[2],
+ "file_name": node[2],
+ }
+ nodes.append(
+ TextNode(id_=node[0], text=node[1], metadata=node_metadata)
+ )
+
+ print(f"[DataReader-Evaluation Dataset] Split into {len(nodes)} nodes.")
+
+ print("[DataReader-Evaluation Dataset] Start inserting to index.")
+
+ self.index.vector_index.insert_nodes(nodes)
+ self.index.vector_index.storage_context.persist(
+ persist_dir=self.index.persist_path
+ )
+
+ index_metadata_file = os.path.join(
+ self.index.persist_path, "index.metadata"
+ )
+ if self.bm25_index:
+ self.bm25_index.add_docs(nodes)
+ metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"})
+ with open(index_metadata_file, "w") as wf:
+ wf.write(metadata_str)
+
+ print(
+ f"[DataReader-Evaluation Dataset] Inserted {len(nodes)} nodes successfully."
+ )
+ return
+ else:
+ raise ValueError(f"Not supported eval dataset name with {name}")
diff --git a/src/pai_rag/integrations/readers/pai_pdf_reader.py b/src/pai_rag/integrations/readers/pai_pdf_reader.py
index b138e073..8c57dfa6 100644
--- a/src/pai_rag/integrations/readers/pai_pdf_reader.py
+++ b/src/pai_rag/integrations/readers/pai_pdf_reader.py
@@ -366,7 +366,7 @@ def load(
md_content = self.parse_pdf(file_path, "auto")
images_with_content = PaiPDFReader.combine_images_with_text(md_content)
md_contend_without_images_url = PaiPDFReader.remove_image_paths(md_content)
-
+ print(f"[PaiPDFReader] successfully processed pdf file {file_path}.")
docs = []
image_documents = []
text_image_documents = []
@@ -374,6 +374,7 @@ def load(
if not isinstance(extra_info, dict):
raise TypeError("extra_info must be a dictionary.")
if self.enable_multimodal:
+ print("[PaiPDFReader] Using multimodal.")
images_url_set = set()
for content, image_urls in images_with_content.items():
images_url_set.update(image_urls)
@@ -383,6 +384,7 @@ def load(
extra_info={"image_url": image_urls, **extra_info},
)
)
+ print("[PaiPDFReader] successfully loaded images with multimodal.")
image_documents.extend(
ImageDocument(
image_url=image_url,
@@ -406,4 +408,5 @@ def load(
docs.extend(image_documents)
docs.extend(text_image_documents)
+ print(f"[PaiPDFReader] successfully loaded {len(docs)} nodes.")
return docs
diff --git a/src/pai_rag/modules/cache/oss_cache.py b/src/pai_rag/modules/cache/oss_cache.py
index e61dc870..0790ef61 100644
--- a/src/pai_rag/modules/cache/oss_cache.py
+++ b/src/pai_rag/modules/cache/oss_cache.py
@@ -19,10 +19,13 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
if cache_config:
oss_bucket = cache_config.get("bucket", None)
oss_endpoint = cache_config.get("endpoint", None)
+ oss_prefix = cache_config.get("prefix", None)
if oss_bucket:
logger.info(f"Using OSS bucket {oss_bucket} for caching objects.")
- return OssClient(bucket_name=oss_bucket, endpoint=oss_endpoint)
+ return OssClient(
+ bucket_name=oss_bucket, endpoint=oss_endpoint, prefix=oss_prefix
+ )
else:
logger.info("No OSS config provided. Will not cache objects.")
return None
diff --git a/src/pai_rag/modules/datareader/data_loader.py b/src/pai_rag/modules/datareader/data_loader.py
index 9eb98a42..aa48bf04 100644
--- a/src/pai_rag/modules/datareader/data_loader.py
+++ b/src/pai_rag/modules/datareader/data_loader.py
@@ -1,6 +1,8 @@
from typing import Any, Dict, List
+from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.data.rag_dataloader import RagDataLoader
+from pai_rag.data.rag_oss_dataloader import OssDataLoader
import logging
logger = logging.getLogger(__name__)
@@ -19,6 +21,7 @@ def get_dependencies() -> List[str]:
]
def _create_new_instance(self, new_params: Dict[str, Any]):
+ self.loader_config = new_params[MODULE_PARAM_CONFIG]
oss_cache = new_params["OssCacheModule"]
data_reader_factory = new_params["DataReaderFactoryModule"]
node_parser = new_params["NodeParserModule"]
@@ -26,6 +29,21 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
bm25_index = new_params["BM25IndexModule"]
node_enhance = new_params["NodesEnhancementModule"]
- return RagDataLoader(
- data_reader_factory, node_parser, index, bm25_index, oss_cache, node_enhance
- )
+ if self.loader_config["type"].lower() == "local":
+ return RagDataLoader(
+ data_reader_factory,
+ node_parser,
+ index,
+ bm25_index,
+ oss_cache,
+ node_enhance,
+ )
+ elif self.loader_config["type"].lower() == "oss":
+ return OssDataLoader(
+ data_reader_factory,
+ node_parser,
+ index,
+ bm25_index,
+ oss_cache,
+ node_enhance,
+ )
diff --git a/src/pai_rag/modules/datareader/datareader_factory.py b/src/pai_rag/modules/datareader/datareader_factory.py
index 4688d150..1ac40d9b 100644
--- a/src/pai_rag/modules/datareader/datareader_factory.py
+++ b/src/pai_rag/modules/datareader/datareader_factory.py
@@ -60,7 +60,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
return self
- def get_reader(self, input_files: str):
+ def get_reader(self, input_files: str = None):
if self.reader_config["type"] == "SimpleDirectoryReader":
return SimpleDirectoryReader(
input_files=input_files,
diff --git a/src/pai_rag/modules/module_registry.py b/src/pai_rag/modules/module_registry.py
index 381755ec..61eb6d6d 100644
--- a/src/pai_rag/modules/module_registry.py
+++ b/src/pai_rag/modules/module_registry.py
@@ -64,7 +64,7 @@ def _get_param_hash(self, params: Dict[str, Any]):
return hashlib.sha256(repr_str).hexdigest()
def get_module_with_config(self, module_key, config):
- key = repr(config)
+ key = repr(config.to_dict())
if key in self._cache_by_config and module_key in self._cache_by_config[key]:
return self._cache_by_config[key][module_key]
@@ -77,7 +77,7 @@ def get_module_with_config(self, module_key, config):
return mod
def init_modules(self, config):
- key = repr(config)
+ key = repr(config.to_dict())
mod_cache = {}
mod_stack = []
diff --git a/src/pai_rag/utils/oss_client.py b/src/pai_rag/utils/oss_client.py
index e980d258..20bcff57 100644
--- a/src/pai_rag/utils/oss_client.py
+++ b/src/pai_rag/utils/oss_client.py
@@ -1,18 +1,20 @@
import logging
import hashlib
-
import oss2
+import os
from oss2.credentials import EnvironmentVariableCredentialsProvider
logger = logging.getLogger(__name__)
class OssClient:
- def __init__(self, bucket_name: str, endpoint: str):
+ def __init__(self, bucket_name: str, endpoint: str, prefix: str = None):
auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider())
self.bucket_name = bucket_name
self.endpoint = endpoint
self.base_url = self._make_url()
+        # Strip surrounding whitespace from the prefix and drop any trailing slash.
+        self.prefix = prefix.strip().rstrip("/") if prefix else ""
"""
确认上面的参数都填写正确了,如果任何一个参数包含 '<',意味着这个参数可能没有被正确设置,而是保留了一个占位符或默认值(
@@ -33,6 +35,11 @@ def get_object(self, key: str) -> bytes:
logger.info("file does not exist")
return None
+ def get_object_to_file(self, key, filename):
+ if not os.path.exists(os.path.dirname(filename)):
+ os.makedirs(os.path.dirname(filename))
+ self.bucket.get_object_to_file(key=key, filename=filename)
+
def put_object(self, key: str, data: bytes, headers=None) -> None:
self.bucket.put_object(key, data, headers=headers)
@@ -53,8 +60,38 @@ def put_object_if_not_exists(
return f"{self.base_url}{key}"
+ def get_obj_key_url(self, filename: str):
+ return f"{self.base_url}{self.prefix}/{filename}"
+
def _make_url(self):
base_endpoint = (
self.endpoint.replace("https://", "").replace("http://", "").strip("/")
)
return f"https://{self.bucket_name}.{base_endpoint}/"
+
+    def list_objects(self):
+        """
+        List the objects in the bucket whose keys start with the configured prefix.
+
+        Wraps the bucket's list_objects call, passing self.prefix so that only
+        objects under that prefix are returned.
+
+        Returns:
+        - list: the matching OSS objects.
+        """
+        # Query the bucket using the configured prefix.
+        res = self.bucket.list_objects(prefix=self.prefix)
+        return res.object_list
+
+ def put_object_acl(self, key, permission):
+ if key.endswith(".txt"):
+ self.bucket.update_object_meta(
+ key, {"Content-Type": "text/plain;charset=utf-8"}
+ )
+
+ res = self.bucket.put_object_acl(key=key, permission=permission)
+
+ return res.status == 200
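To illustrate how the new prefix handling fits together, here is a minimal sketch. The bucket name, endpoint, prefix, and file name are illustrative, and credentials are resolved from the environment by oss2's EnvironmentVariableCredentialsProvider (OSS_ACCESS_KEY_ID / OSS_ACCESS_KEY_SECRET):

```python
from pai_rag.utils.oss_client import OssClient

# Bucket, endpoint, and prefix below are example values, not part of this patch.
client = OssClient(
    bucket_name="my-rag-bucket",
    endpoint="https://oss-cn-hangzhou.aliyuncs.com",
    prefix=" pai_rag_docs/ ",  # normalized to "pai_rag_docs"
)

# get_obj_key_url() joins base_url, prefix, and file name:
# https://my-rag-bucket.oss-cn-hangzhou.aliyuncs.com/pai_rag_docs/report.pdf
print(client.get_obj_key_url("report.pdf"))

# list_objects() returns the objects under the prefix; OssDataLoader uses this
# in _get_oss_files() to mirror them into localdata/oss_tmp before ingestion.
for obj in client.list_objects():
    print(obj.key)
```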