diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py index 3677f3fe..d32c9dc0 100644 --- a/src/pai_rag/app/api/query.py +++ b/src/pai_rag/app/api/query.py @@ -138,9 +138,33 @@ async def upload_data( task_id=task_id, input_files=input_files, filter_pattern=None, + oss_prefix=None, faiss_path=faiss_path, enable_qa_extraction=False, enable_raptor=enable_raptor, ) return {"task_id": task_id} + + +@router.post("/upload_data_from_oss") +async def upload_oss_data( + oss_prefix: str = None, + faiss_path: str = None, + enable_raptor: bool = False, + background_tasks: BackgroundTasks = BackgroundTasks(), +): + task_id = uuid.uuid4().hex + background_tasks.add_task( + rag_service.add_knowledge_async, + task_id=task_id, + input_files=None, + filter_pattern=None, + oss_prefix=oss_prefix, + faiss_path=faiss_path, + enable_qa_extraction=False, + enable_raptor=enable_raptor, + from_oss=True, + ) + + return {"task_id": task_id} diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index 0f0a3879..ab4216f0 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -1,7 +1,6 @@ import json from typing import Any import requests -import html import httpx import os import re @@ -92,10 +91,13 @@ def _format_rag_response( elif is_finished: for i, doc in enumerate(docs): filename = doc["metadata"].get("file_name", None) + file_url = doc["metadata"].get("file_url", None) if filename: formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename) + if file_url: + formatted_file_name = f"""[{formatted_file_name}]({file_url})""" referenced_docs += ( - f'[{i+1}]: {formatted_file_name} Score:{doc["score"]} \n' + f'[{i+1}]: {formatted_file_name} Score:{doc["score"]}\n' ) formatted_answer = "" @@ -185,7 +187,7 @@ def query_vector(self, text: str): response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=text) else: for i, doc in enumerate(response["docs"]): - # html_content = markdown.markdown() + file_url = doc.get("metadata", {}).get("file_url", None) media_url = doc.get("metadata", {}).get("image_url", None) if media_url and isinstance(media_url, list): media_url = "
".join( @@ -196,7 +198,11 @@ def query_vector(self, text: str): ) elif media_url: media_url = f"""""" - safe_html_content = html.escape(doc["text"]).replace("\n", "
") + safe_html_content = doc["text"] + if file_url: + safe_html_content = ( + f"""{safe_html_content}""" + ) formatted_text += 'Doc {}{}{}{}\n'.format( i + 1, doc["score"], safe_html_content, media_url ) diff --git a/src/pai_rag/config/settings_multi_modal.toml b/src/pai_rag/config/settings_multi_modal.toml index b6407129..601a9ca0 100644 --- a/src/pai_rag/config/settings_multi_modal.toml +++ b/src/pai_rag/config/settings_multi_modal.toml @@ -25,6 +25,9 @@ host = "Aliyun-Redis host" password = "Aliyun-Redis user:pwd" persist_path = "localdata/storage" +[rag.data_loader] +type = "Local" # [Local, Oss] + [rag.data_reader] type = "SimpleDirectoryReader" enable_multimodal = true @@ -82,6 +85,7 @@ chunk_overlap = 10 [rag.oss_store] bucket = "" endpoint = "" +prefix = "" [rag.postprocessor] reranker_type = "simple-weighted-reranker" # [simple-weighted-reranker, model-based-reranker] diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py index be4d69c2..0c978c4f 100644 --- a/src/pai_rag/core/rag_application.py +++ b/src/pai_rag/core/rag_application.py @@ -11,6 +11,7 @@ import json import logging import os +import copy from uuid import uuid4 DEFAULT_EMPTY_RESPONSE_GEN = "Empty Response" @@ -60,8 +61,9 @@ async def aload_knowledge( enable_raptor=False, ): sessioned_config = self.config + sessioned_config.rag.data_loader.update({"type": "Local"}) if faiss_path: - sessioned_config = self.config.copy() + sessioned_config = copy.copy(self.config) sessioned_config.rag.index.update({"persist_path": faiss_path}) self.logger.info( f"Update rag_application config with faiss_persist_path: {faiss_path}" @@ -74,13 +76,42 @@ async def aload_knowledge( input_files, filter_pattern, enable_qa_extraction, enable_raptor ) + async def aload_knowledge_from_oss( + self, + filter_pattern=None, + oss_prefix=None, + faiss_path=None, + enable_qa_extraction=False, + enable_raptor=False, + ): + sessioned_config = copy.copy(self.config) + sessioned_config.rag.data_loader.update({"type": "Oss"}) + sessioned_config.rag.oss_store.update({"prefix": oss_prefix}) + _ = module_registry.get_module_with_config("OssCacheModule", sessioned_config) + self.logger.info( + f"Update rag_application config with data_loader type: Oss and Oss Bucket prefix: {oss_prefix}" + ) + data_loader = module_registry.get_module_with_config( + "DataLoaderModule", sessioned_config + ) + if faiss_path: + sessioned_config.rag.index.update({"persist_path": faiss_path}) + self.logger.info( + f"Update rag_application config with faiss_persist_path: {faiss_path}" + ) + await data_loader.aload( + filter_pattern=filter_pattern, + enable_qa_extraction=enable_qa_extraction, + enable_raptor=enable_raptor, + ) + async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse: if not query.question: return RetrievalResponse(docs=[]) sessioned_config = self.config if query.vector_db and query.vector_db.faiss_path: - sessioned_config = self.config.copy() + sessioned_config = copy.copy(self.config) sessioned_config.rag.index.update( {"persist_path": query.vector_db.faiss_path} ) @@ -123,7 +154,7 @@ async def aquery_rag(self, query: RagQuery): sessioned_config = self.config if query.vector_db and query.vector_db.faiss_path: - sessioned_config = self.config.copy() + sessioned_config = copy.copy(self.config) sessioned_config.rag.index.update( {"persist_path": query.vector_db.faiss_path} ) diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py index 4412ee1b..75fe9344 100644 --- a/src/pai_rag/core/rag_service.py +++ b/src/pai_rag/core/rag_service.py @@ -99,23 +99,34 @@ def check_updates(self): async def add_knowledge_async( self, task_id: str, - input_files: List[str], + input_files: List[str] = None, filter_pattern: str = None, + oss_prefix: str = None, faiss_path: str = None, enable_qa_extraction: bool = False, enable_raptor: bool = False, + from_oss: bool = False, ): self.check_updates() with open(TASK_STATUS_FILE, "a") as f: f.write(f"{task_id}\tprocessing\n") try: - await self.rag.aload_knowledge( - input_files, - filter_pattern, - faiss_path, - enable_qa_extraction, - enable_raptor, - ) + if not from_oss: + await self.rag.aload_knowledge( + input_files, + filter_pattern, + faiss_path, + enable_qa_extraction, + enable_raptor, + ) + else: + await self.rag.aload_knowledge_from_oss( + filter_pattern, + oss_prefix, + faiss_path, + enable_qa_extraction, + enable_raptor, + ) with open(TASK_STATUS_FILE, "a") as f: f.write(f"{task_id}\tcompleted\n") except Exception as ex: diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py index 38bca41e..07b3eeb4 100644 --- a/src/pai_rag/data/rag_dataloader.py +++ b/src/pai_rag/data/rag_dataloader.py @@ -99,6 +99,11 @@ def _get_nodes( filter_pattern: str, enable_qa_extraction: bool, ): + tmp_index_doc = self.index.vector_index._docstore.docs + seen_files = set( + [_doc.metadata.get("file_name") for _, _doc in tmp_index_doc.items()] + ) + filter_pattern = filter_pattern or "*" if isinstance(file_path, list): input_files = [f for f in file_path if os.path.isfile(f)] @@ -115,9 +120,22 @@ def _get_nodes( if len(input_files) == 0: return - data_reader = self.datareader_factory.get_reader(input_files) + # 检查文件名是否已经在seen_files中,如果在,则跳过当前文件 + new_input_files = [] + for input_file in input_files: + if os.path.basename(input_file) in seen_files: + print( + f"[RagDataLoader] {os.path.basename(input_file)} already exists, skip it." + ) + continue + new_input_files.append(input_file) + if len(new_input_files) == 0: + return + print(f"[RagDataLoader] {len(new_input_files)} files will be loaded.") + + data_reader = self.datareader_factory.get_reader(input_files=new_input_files) docs = data_reader.load_data() - logger.info(f"[DataReader] Loaded {len(docs)} docs.") + print(f"[DataReader] Loaded {len(docs)} docs.") nodes = [] doc_cnt_map = {} diff --git a/src/pai_rag/data/rag_oss_dataloader.py b/src/pai_rag/data/rag_oss_dataloader.py new file mode 100644 index 00000000..4f8c3457 --- /dev/null +++ b/src/pai_rag/data/rag_oss_dataloader.py @@ -0,0 +1,371 @@ +import datetime +import json +import os +from typing import Any, Dict, List +from fastapi.concurrency import run_in_threadpool +from llama_index.core import Settings +from llama_index.core.schema import TextNode, ImageNode, ImageDocument +from llama_index.llms.huggingface import HuggingFaceLLM + +from pai_rag.integrations.nodeparsers.base import MarkdownNodeParser +from pai_rag.integrations.extractors.html_qa_extractor import HtmlQAExtractor +from pai_rag.integrations.extractors.text_qa_extractor import TextQAExtractor +from pai_rag.modules.nodeparser.node_parser import node_id_hash +from pai_rag.data.open_dataset import MiraclOpenDataSet, DuRetrievalDataSet + + +import logging +import re + +logger = logging.getLogger(__name__) + +DEFAULT_LOCAL_QA_MODEL_PATH = "./model_repository/qwen_1.8b" +DOC_TYPES_DO_NOT_NEED_CHUNKING = set( + [".csv", ".xlsx", ".xls", ".htm", ".html", ".jsonl"] +) +IMAGE_FILE_TYPES = set([".jpg", ".jpeg", ".png"]) + +IMAGE_URL_REGEX = re.compile( + r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+\.(?:jpg|jpeg|png)", + re.IGNORECASE, +) + + +class OssDataLoader: + """ + OssDataLoader: + Load data with corresponding data readers according to config. + """ + + def __init__( + self, + datareader_factory, + node_parser, + index, + bm25_index, + oss_cache, + node_enhance, + use_local_qa_model=False, + ): + self.datareader_factory = datareader_factory + self.node_parser = node_parser + self.oss_cache = oss_cache + self.index = index + self.bm25_index = bm25_index + self.node_enhance = node_enhance + + if use_local_qa_model: + # API暂不支持此选项 + self.qa_llm = HuggingFaceLLM( + model_name=DEFAULT_LOCAL_QA_MODEL_PATH, + tokenizer_name=DEFAULT_LOCAL_QA_MODEL_PATH, + ) + else: + self.qa_llm = Settings.llm + + html_extractor = HtmlQAExtractor(llm=self.qa_llm) + txt_extractor = TextQAExtractor(llm=self.qa_llm) + + self.extractors = [html_extractor, txt_extractor] + + logger.info("OssDataLoader initialized.") + + def _extract_file_type(self, metadata: Dict[str, Any]): + file_name = metadata.get("file_name", "dummy.txt") + return os.path.splitext(file_name)[1] + + def _get_oss_files(self): + files = [] + if self.oss_cache: + object_list = self.oss_cache.list_objects() + oss_file_path_dir = "localdata/oss_tmp" + if not os.path.exists(oss_file_path_dir): + os.makedirs(oss_file_path_dir) + for oss_obj in object_list: + if oss_obj.key[:-1] != self.oss_cache.prefix: + try: + set_public = self.oss_cache.put_object_acl( + oss_obj.key, "public-read" + ) + except Exception: + logger.error(f"Failed to set_public document {oss_obj.key}") + if set_public: + save_filename = os.path.join(oss_file_path_dir, oss_obj.key) + self.oss_cache.get_object_to_file( + key=oss_obj.key, filename=save_filename + ) + files.append(save_filename) + else: + logger.error(f"Failed to load document {oss_obj.key}") + return files + + def _get_nodes( + self, + file_path: str | List[str], + filter_pattern: str, + enable_qa_extraction: bool, + ): + tmp_index_doc = self.index.vector_index._docstore.docs + seen_files = set( + [_doc.metadata.get("file_name") for _, _doc in tmp_index_doc.items()] + ) + filter_pattern = filter_pattern or "*" + if isinstance(file_path, list): + input_files = [f for f in file_path if os.path.isfile(f)] + elif isinstance(file_path, str) and os.path.isdir(file_path): + import pathlib + + directory = pathlib.Path(file_path) + input_files = [ + f for f in directory.rglob(filter_pattern) if os.path.isfile(f) + ] + else: + input_files = [file_path] + + if len(input_files) == 0: + return + + # 检查文件名是否已经在seen_files中,如果在,则跳过当前文件 + new_input_files = [] + for input_file in input_files: + if os.path.basename(input_file) in seen_files: + print( + f"[RagOssDataLoader] {os.path.basename(input_file)} already exists, skip it." + ) + continue + new_input_files.append(input_file) + if len(new_input_files) == 0: + return + print(f"[RagOssDataLoader] {len(new_input_files)} files will be loaded.") + + data_reader = self.datareader_factory.get_reader(new_input_files) + docs = data_reader.load_data() + logger.info(f"[DataReader] Loaded {len(docs)} docs.") + nodes = [] + + doc_cnt_map = {} + for doc in docs: + doc_type = self._extract_file_type(doc.metadata) + doc.metadata["file_path"] = os.path.basename(doc.metadata["file_path"]) + doc.metadata["file_url"] = self.oss_cache.get_obj_key_url( + doc.metadata["file_path"] + ) + doc_key = f"""{doc.metadata.get("file_path", "dummy")}""" + if doc_key not in doc_cnt_map: + doc_cnt_map[doc_key] = 0 + + if isinstance(doc, ImageDocument): + node_id = node_id_hash(doc_cnt_map[doc_key], doc) + doc_cnt_map[doc_key] += 1 + nodes.append( + ImageNode( + id_=node_id, image_url=doc.image_url, metadata=doc.metadata + ) + ) + elif doc_type in DOC_TYPES_DO_NOT_NEED_CHUNKING: + doc_key = f"""{doc.metadata.get("file_path", "dummy")}""" + doc_cnt_map[doc_key] += 1 + node_id = node_id_hash(doc_cnt_map[doc_key], doc) + nodes.append( + TextNode(id_=node_id, text=doc.text, metadata=doc.metadata) + ) + elif doc_type == ".md" or doc_type == ".pdf": + md_node_parser = MarkdownNodeParser(id_func=node_id_hash) + tmp_nodes = md_node_parser.get_nodes_from_documents([doc]) + for node in tmp_nodes: + node.id_ = node_id_hash(doc_cnt_map[doc_key], doc) + doc_cnt_map[doc_key] += 1 + nodes.append(node) + else: + nodes.extend(self.node_parser.get_nodes_from_documents([doc])) + + for node in nodes: + node.excluded_embed_metadata_keys.append("file_path") + node.excluded_embed_metadata_keys.append("image_url") + node.excluded_embed_metadata_keys.append("total_pages") + node.excluded_embed_metadata_keys.append("source") + + logger.info(f"[DataReader] Split into {len(nodes)} nodes.") + + # QA metadata extraction + if enable_qa_extraction: + qa_nodes = [] + + for extractor in self.extractors: + metadata_list = extractor.extract(nodes) + for i, node in enumerate(nodes): + qa_extraction_result = metadata_list[i].get( + "qa_extraction_result", {} + ) + q_cnt = 0 + metadata = node.metadata + for q, a in qa_extraction_result.items(): + metadata["answer"] = a + qa_nodes.append( + TextNode( + id_=f"{node.id_}_{q_cnt}", text=q, metadata=metadata + ) + ) + q_cnt += 1 + for node in qa_nodes: + node.excluded_embed_metadata_keys.append("answer") + node.excluded_llm_metadata_keys.append("question") + nodes.extend(qa_nodes) + + return nodes + + def load( + self, + filter_pattern: str, + enable_qa_extraction: bool, + enable_raptor: bool, + ): + file_path = self._get_oss_files() + nodes = self._get_nodes(file_path, filter_pattern, enable_qa_extraction) + + if not nodes: + logger.warning("[DataReader] no nodes parsed.") + return + + logger.info("[DataReader] Start inserting to index.") + + if enable_raptor: + nodes_with_embeddings = self.node_enhance(nodes=nodes) + self.index.vector_index.insert_nodes(nodes_with_embeddings) + + logger.info( + f"Inserted {len(nodes)} and enhanced {len(nodes_with_embeddings)-len(nodes)} nodes successfully." + ) + else: + self.index.vector_index.insert_nodes(nodes) + logger.info(f"Inserted {len(nodes)} nodes successfully.") + + self.index.vector_index.storage_context.persist( + persist_dir=self.index.persist_path + ) + + index_metadata_file = os.path.join(self.index.persist_path, "index.metadata") + if self.bm25_index: + self.bm25_index.add_docs(nodes) + metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"}) + with open(index_metadata_file, "w") as wf: + wf.write(metadata_str) + + return + + async def aload( + self, + filter_pattern: str, + enable_qa_extraction: bool, + enable_raptor: bool, + ): + file_path = self._get_oss_files() + nodes = await run_in_threadpool( + lambda: self._get_nodes(file_path, filter_pattern, enable_qa_extraction) + ) + if not nodes: + logger.info("[DataReader] could not find files") + return + + logger.info("[DataReader] Start inserting to index.") + + if enable_raptor: + nodes_with_embeddings = await self.node_enhance.acall(nodes=nodes) + self.index.vector_index.insert_nodes(nodes_with_embeddings) + + logger.info( + f"Async inserted {len(nodes)} and enhanced {len(nodes_with_embeddings)-len(nodes)} nodes successfully." + ) + + else: + self.index.vector_index.insert_nodes(nodes) + logger.info(f"Inserted {len(nodes)} nodes successfully.") + + self.index.vector_index.storage_context.persist( + persist_dir=self.index.persist_path + ) + + index_metadata_file = os.path.join(self.index.persist_path, "index.metadata") + if self.bm25_index: + await run_in_threadpool(lambda: self.bm25_index.add_docs(nodes)) + metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"}) + with open(index_metadata_file, "w") as wf: + wf.write(metadata_str) + + return + + def load_eval_data(self, name: str): + logger.info("[DataReader-Evaluation Dataset]") + if name == "miracl": + miracl_dataset = MiraclOpenDataSet() + miracl_nodes, _ = miracl_dataset.load_related_corpus() + nodes = [] + for node in miracl_nodes: + node_metadata = { + "title": node[2], + "file_path": node[3], + "file_name": node[3], + } + nodes.append( + TextNode(id_=node[0], text=node[1], metadata=node_metadata) + ) + + print(f"[DataReader-Evaluation Dataset] Split into {len(nodes)} nodes.") + + print("[DataReader-Evaluation Dataset] Start inserting to index.") + + self.index.vector_index.insert_nodes(nodes) + self.index.vector_index.storage_context.persist( + persist_dir=self.index.persist_path + ) + + index_metadata_file = os.path.join( + self.index.persist_path, "index.metadata" + ) + if self.bm25_index: + self.bm25_index.add_docs(nodes) + metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"}) + with open(index_metadata_file, "w") as wf: + wf.write(metadata_str) + + print( + f"[DataReader-Evaluation Dataset] Inserted {len(nodes)} nodes successfully." + ) + return + elif name == "duretrieval": + duretrieval_dataset = DuRetrievalDataSet() + miracl_nodes, _, _ = duretrieval_dataset.load_related_corpus() + nodes = [] + for node in miracl_nodes: + node_metadata = { + "file_path": node[2], + "file_name": node[2], + } + nodes.append( + TextNode(id_=node[0], text=node[1], metadata=node_metadata) + ) + + print(f"[DataReader-Evaluation Dataset] Split into {len(nodes)} nodes.") + + print("[DataReader-Evaluation Dataset] Start inserting to index.") + + self.index.vector_index.insert_nodes(nodes) + self.index.vector_index.storage_context.persist( + persist_dir=self.index.persist_path + ) + + index_metadata_file = os.path.join( + self.index.persist_path, "index.metadata" + ) + if self.bm25_index: + self.bm25_index.add_docs(nodes) + metadata_str = json.dumps({"lastUpdated": f"{datetime.datetime.now()}"}) + with open(index_metadata_file, "w") as wf: + wf.write(metadata_str) + + print( + f"[DataReader-Evaluation Dataset] Inserted {len(nodes)} nodes successfully." + ) + return + else: + raise ValueError(f"Not supported eval dataset name with {name}") diff --git a/src/pai_rag/integrations/readers/pai_pdf_reader.py b/src/pai_rag/integrations/readers/pai_pdf_reader.py index b138e073..8c57dfa6 100644 --- a/src/pai_rag/integrations/readers/pai_pdf_reader.py +++ b/src/pai_rag/integrations/readers/pai_pdf_reader.py @@ -366,7 +366,7 @@ def load( md_content = self.parse_pdf(file_path, "auto") images_with_content = PaiPDFReader.combine_images_with_text(md_content) md_contend_without_images_url = PaiPDFReader.remove_image_paths(md_content) - + print(f"[PaiPDFReader] successfully processed pdf file {file_path}.") docs = [] image_documents = [] text_image_documents = [] @@ -374,6 +374,7 @@ def load( if not isinstance(extra_info, dict): raise TypeError("extra_info must be a dictionary.") if self.enable_multimodal: + print("[PaiPDFReader] Using multimodal.") images_url_set = set() for content, image_urls in images_with_content.items(): images_url_set.update(image_urls) @@ -383,6 +384,7 @@ def load( extra_info={"image_url": image_urls, **extra_info}, ) ) + print("[PaiPDFReader] successfully loaded images with multimodal.") image_documents.extend( ImageDocument( image_url=image_url, @@ -406,4 +408,5 @@ def load( docs.extend(image_documents) docs.extend(text_image_documents) + print(f"[PaiPDFReader] successfully loaded {len(docs)} nodes.") return docs diff --git a/src/pai_rag/modules/cache/oss_cache.py b/src/pai_rag/modules/cache/oss_cache.py index e61dc870..0790ef61 100644 --- a/src/pai_rag/modules/cache/oss_cache.py +++ b/src/pai_rag/modules/cache/oss_cache.py @@ -19,10 +19,13 @@ def _create_new_instance(self, new_params: Dict[str, Any]): if cache_config: oss_bucket = cache_config.get("bucket", None) oss_endpoint = cache_config.get("endpoint", None) + oss_prefix = cache_config.get("prefix", None) if oss_bucket: logger.info(f"Using OSS bucket {oss_bucket} for caching objects.") - return OssClient(bucket_name=oss_bucket, endpoint=oss_endpoint) + return OssClient( + bucket_name=oss_bucket, endpoint=oss_endpoint, prefix=oss_prefix + ) else: logger.info("No OSS config provided. Will not cache objects.") return None diff --git a/src/pai_rag/modules/datareader/data_loader.py b/src/pai_rag/modules/datareader/data_loader.py index 9eb98a42..aa48bf04 100644 --- a/src/pai_rag/modules/datareader/data_loader.py +++ b/src/pai_rag/modules/datareader/data_loader.py @@ -1,6 +1,8 @@ from typing import Any, Dict, List +from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.data.rag_dataloader import RagDataLoader +from pai_rag.data.rag_oss_dataloader import OssDataLoader import logging logger = logging.getLogger(__name__) @@ -19,6 +21,7 @@ def get_dependencies() -> List[str]: ] def _create_new_instance(self, new_params: Dict[str, Any]): + self.loader_config = new_params[MODULE_PARAM_CONFIG] oss_cache = new_params["OssCacheModule"] data_reader_factory = new_params["DataReaderFactoryModule"] node_parser = new_params["NodeParserModule"] @@ -26,6 +29,21 @@ def _create_new_instance(self, new_params: Dict[str, Any]): bm25_index = new_params["BM25IndexModule"] node_enhance = new_params["NodesEnhancementModule"] - return RagDataLoader( - data_reader_factory, node_parser, index, bm25_index, oss_cache, node_enhance - ) + if self.loader_config["type"].lower() == "local": + return RagDataLoader( + data_reader_factory, + node_parser, + index, + bm25_index, + oss_cache, + node_enhance, + ) + elif self.loader_config["type"].lower() == "oss": + return OssDataLoader( + data_reader_factory, + node_parser, + index, + bm25_index, + oss_cache, + node_enhance, + ) diff --git a/src/pai_rag/modules/datareader/datareader_factory.py b/src/pai_rag/modules/datareader/datareader_factory.py index 4688d150..1ac40d9b 100644 --- a/src/pai_rag/modules/datareader/datareader_factory.py +++ b/src/pai_rag/modules/datareader/datareader_factory.py @@ -60,7 +60,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]): return self - def get_reader(self, input_files: str): + def get_reader(self, input_files: str = None): if self.reader_config["type"] == "SimpleDirectoryReader": return SimpleDirectoryReader( input_files=input_files, diff --git a/src/pai_rag/modules/module_registry.py b/src/pai_rag/modules/module_registry.py index 381755ec..61eb6d6d 100644 --- a/src/pai_rag/modules/module_registry.py +++ b/src/pai_rag/modules/module_registry.py @@ -64,7 +64,7 @@ def _get_param_hash(self, params: Dict[str, Any]): return hashlib.sha256(repr_str).hexdigest() def get_module_with_config(self, module_key, config): - key = repr(config) + key = repr(config.to_dict()) if key in self._cache_by_config and module_key in self._cache_by_config[key]: return self._cache_by_config[key][module_key] @@ -77,7 +77,7 @@ def get_module_with_config(self, module_key, config): return mod def init_modules(self, config): - key = repr(config) + key = repr(config.to_dict()) mod_cache = {} mod_stack = [] diff --git a/src/pai_rag/utils/oss_client.py b/src/pai_rag/utils/oss_client.py index e980d258..20bcff57 100644 --- a/src/pai_rag/utils/oss_client.py +++ b/src/pai_rag/utils/oss_client.py @@ -1,18 +1,20 @@ import logging import hashlib - import oss2 +import os from oss2.credentials import EnvironmentVariableCredentialsProvider logger = logging.getLogger(__name__) class OssClient: - def __init__(self, bucket_name: str, endpoint: str): + def __init__(self, bucket_name: str, endpoint: str, prefix: str = None): auth = oss2.ProviderAuth(EnvironmentVariableCredentialsProvider()) self.bucket_name = bucket_name self.endpoint = endpoint self.base_url = self._make_url() + # 去除prefix可能存在的前后空格,并去除最后的斜杠 + self.prefix = prefix.strip().rstrip("/") """ 确认上面的参数都填写正确了,如果任何一个参数包含 '<',意味着这个参数可能没有被正确设置,而是保留了一个占位符或默认值( @@ -33,6 +35,11 @@ def get_object(self, key: str) -> bytes: logger.info("file does not exist") return None + def get_object_to_file(self, key, filename): + if not os.path.exists(os.path.dirname(filename)): + os.makedirs(os.path.dirname(filename)) + self.bucket.get_object_to_file(key=key, filename=filename) + def put_object(self, key: str, data: bytes, headers=None) -> None: self.bucket.put_object(key, data, headers=headers) @@ -53,8 +60,38 @@ def put_object_if_not_exists( return f"{self.base_url}{key}" + def get_obj_key_url(self, filename: str): + return f"{self.base_url}{self.prefix}/{filename}" + def _make_url(self): base_endpoint = ( self.endpoint.replace("https://", "").replace("http://", "").strip("/") ) return f"https://{self.bucket_name}.{base_endpoint}/" + + def list_objects(self): + """ + 列出存储桶中指定前缀的对象列表。 + + 该方法通过调用oss bucket的list_objects函数,查询与给定前缀匹配的所有对象,并返回这些对象的列表。 + + 参数: + - prefix (str): 对象名的前缀,用于筛选满足条件的对象。默认为空字符串,表示不指定前缀,即列出所有对象。 + + 返回: + - list: 包含满足前缀条件的所有对象的列表。 + """ + # 调用bucket的list_objects方法,传入前缀参数 + res = self.bucket.list_objects(prefix=self.prefix) + # 返回查询到的对象列表 + return res.object_list + + def put_object_acl(self, key, permission): + if key.endswith(".txt"): + self.bucket.update_object_meta( + key, {"Content-Type": "text/plain;charset=utf-8"} + ) + + res = self.bucket.put_object_acl(key=key, permission=permission) + + return res.status == 200