diff --git a/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py b/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py index 7baa6c4c..1a508dce 100644 --- a/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py +++ b/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py @@ -40,10 +40,8 @@ class NodeParserConfig(BaseModel): buffer_size: int = DEFAULT_BUFFER_SIZE -DOC_TYPES_DO_NOT_NEED_CHUNKING = set( - [".csv", ".xlsx", ".xls", ".htm", ".html", ".jsonl"] -) -DOC_TYPES_CONVERT_TO_MD = set([".md", ".pdf", ".docx"]) +DOC_TYPES_DO_NOT_NEED_CHUNKING = set([".csv", ".xlsx", ".xls", ".jsonl"]) +DOC_TYPES_CONVERT_TO_MD = set([".md", ".pdf", ".docx", ".htm", ".html"]) IMAGE_FILE_TYPES = set([".jpg", ".jpeg", ".png"]) IMAGE_URL_REGEX = re.compile( diff --git a/src/pai_rag/integrations/readers/pai/pai_data_reader.py b/src/pai_rag/integrations/readers/pai/pai_data_reader.py index 993e8bf1..49bcef7d 100644 --- a/src/pai_rag/integrations/readers/pai/pai_data_reader.py +++ b/src/pai_rag/integrations/readers/pai/pai_data_reader.py @@ -4,7 +4,7 @@ from pai_rag.integrations.readers.markdown_reader import MarkdownReader from pai_rag.integrations.readers.pai_image_reader import PaiImageReader from pai_rag.integrations.readers.pai_pdf_reader import PaiPDFReader -from pai_rag.integrations.readers.html.html_reader import HtmlReader +from pai_rag.integrations.readers.pai_html_reader import PaiHtmlReader from pai_rag.integrations.readers.pai_csv_reader import PaiPandasCSVReader from pai_rag.integrations.readers.pai_excel_reader import PaiPandasExcelReader from pai_rag.integrations.readers.pai_jsonl_reader import PaiJsonLReader @@ -32,8 +32,14 @@ def get_file_readers(reader_config: BaseDataReaderConfig = None, oss_store: Any image_reader = PaiImageReader(oss_cache=oss_store) file_readers = { - ".html": HtmlReader(), - ".htm": HtmlReader(), + ".html": PaiHtmlReader( + enable_table_summary=reader_config.enable_table_summary, + oss_cache=oss_store, # Storing html 
images + ), + ".htm": PaiHtmlReader( + enable_table_summary=reader_config.enable_table_summary, + oss_cache=oss_store, # Storing html images + ), ".docx": PaiDocxReader( enable_table_summary=reader_config.enable_table_summary, oss_cache=oss_store, # Storing docx images diff --git a/src/pai_rag/integrations/readers/pai_docx_reader.py b/src/pai_rag/integrations/readers/pai_docx_reader.py index d3923419..1a4ba668 100644 --- a/src/pai_rag/integrations/readers/pai_docx_reader.py +++ b/src/pai_rag/integrations/readers/pai_docx_reader.py @@ -223,7 +223,7 @@ def load_data( metadata: bool = True, extra_info: Optional[Dict] = None, ) -> List[Document]: - """Loads list of documents from PDF file and also accepts extra information in dict format.""" + """Loads list of documents from Docx file and also accepts extra information in dict format.""" return self.load(file_path, metadata=metadata, extra_info=extra_info) def load( diff --git a/src/pai_rag/integrations/readers/pai_html_reader.py b/src/pai_rag/integrations/readers/pai_html_reader.py new file mode 100644 index 00000000..5ca4f8ed --- /dev/null +++ b/src/pai_rag/integrations/readers/pai_html_reader.py @@ -0,0 +1,242 @@ +"""Html parser. + +""" +import html2text +import logging +from bs4 import BeautifulSoup +import requests +from typing import Dict, List, Optional, Union, Any +from io import BytesIO +from pai_rag.utils.markdown_utils import ( + transform_local_to_oss, + convert_table_to_markdown, + PaiTable, +) +from pathlib import Path +import re +import time +import os +from PIL import Image +from llama_index.core.readers.base import BaseReader +from llama_index.core.schema import Document + +logger = logging.getLogger(__name__) + +IMAGE_URL_PATTERN = ( + r"!\[(?P<alt_text>.*?)\]\((https?://[^\s]+?[\s\w.-]*\.(jpg|jpeg|png|gif|bmp))\)" +) + + +class PaiHtmlReader(BaseReader): + """Read html files including texts, tables, images. 
+ + Args: + enable_table_summary (bool): whether to use table_summary to process tables + """ + + def __init__( + self, + enable_table_summary: bool = False, + oss_cache: Any = None, + ) -> None: + self.enable_table_summary = enable_table_summary + self._oss_cache = oss_cache + logger.info( + f"PaiHtmlReader created with enable_table_summary : {self.enable_table_summary}" + ) + + def _extract_tables(self, html): + soup = BeautifulSoup(html, "html.parser") + tables = soup.find_all("table") + for table in tables: + # 替换表格内容为一个占位符 + placeholder = f"<table_placeholder_{id(table)}>" + table.replace_with(placeholder) + return str(soup), tables + + def _convert_table_to_pai_table(self, table): + # 标记header的index + row_headers_index = [] + col_headers_index = [] + row_header_flag = True + col_header_index_max = -1 + table_matrix = [] + current_row_index = 0 + max_cols = 0 + max_rows = 0 + for row in table.find_all("tr"): + current_col_index = 0 + if current_row_index == 0: + row_cells = [] + else: + row_cells = [""] * max_cols + if current_row_index >= max_rows: + table_matrix.append(row_cells) + for cell in row.find_all(["th", "td"]): + if cell.name != "th": + row_header_flag = False + else: + col_header_index_max = max(col_header_index_max, current_col_index) + cell_content = self._parse_cell_content(cell) + col_span = int(cell.get("colspan", 1)) + row_span = int(cell.get("rowspan", 1)) + if current_row_index != 0: + while ( + current_col_index < max_cols + and table_matrix[current_row_index][current_col_index] != "" + ): + current_col_index += 1 + if (current_col_index > max_cols and max_cols != 0) or ( + current_row_index > max_rows and max_rows != 0 + ): + break + for i in range(col_span): + if current_row_index == 0: + table_matrix[current_row_index].append(cell_content) + elif current_col_index + i < max_cols: + table_matrix[current_row_index][ + current_col_index + i + ] = cell_content + for i in range(1, row_span): + if current_row_index + i > max_rows: + table_matrix.append(row_cells) + 
table_matrix[current_row_index + i][ + current_col_index + ] = cell_content + max_rows = max(current_row_index + row_span, max_rows) + current_col_index += col_span + if current_row_index == 0: + max_cols += col_span + if row_header_flag: + row_headers_index.append(current_row_index) + current_row_index += 1 + + for i in range(col_header_index_max + 1): + col_headers_index.append(i) + + table = PaiTable( + data=table_matrix, + row_headers_index=row_headers_index, + column_headers_index=col_headers_index, + ) + + return table, max_cols + + def _parse_cell_content(self, cell): + content = [] + for element in cell.contents: + if isinstance(element, str): + content.append(element.strip()) + elif element.name == "p": + p_content = [] + for sub_element in element.contents: + if sub_element.name == "img": + image_url = sub_element.get("src") + p_content.append(f"![]({image_url})") + elif isinstance(sub_element, str): + p_content.append(sub_element.strip()) + else: + p_content.append(sub_element.text.strip()) + content.append(" ".join(p_content)) + else: + content.append(element.text.strip()) + return " ".join(content) + + def _convert_table_to_markdown(self, table): + table, total_cols = self._convert_table_to_pai_table(table) + return convert_table_to_markdown(table, total_cols) + + def _transform_local_to_oss(self, html_name: str, image_url: str): + response = requests.get(image_url) + response.raise_for_status() # 检查请求是否成功 + + # 将二进制数据转换为图像对象 + image = Image.open(BytesIO(response.content)) + return transform_local_to_oss(self._oss_cache, image, html_name) + + def _replace_image_paths(self, html_name: str, content: str): + image_pattern = IMAGE_URL_PATTERN + matches = re.findall(image_pattern, content) + for alt_text, image_url, image_type in matches: + time_tag = int(time.time()) + oss_url = self._transform_local_to_oss(html_name, image_url) + updated_alt_text = f"pai_rag_image_{time_tag}_{alt_text}" + content = content.replace( + f"![{alt_text}]({image_url})", 
f"![{updated_alt_text}]({oss_url})" + ) + + return content + + def convert_html_to_markdown(self, html_path): + html_name = os.path.basename(html_path).split(".")[0] + html_name = html_name.replace(" ", "_") + try: + with open(html_path, "r", encoding="utf-8") as f: + html_content = f.read() + + modified_html, tables = self._extract_tables(html_content) + h = html2text.HTML2Text() + + # 配置 html2text 对象 + h.ignore_links = True # 是否忽略链接 + h.ignore_images = False # 是否忽略图片 + h.escape_all = True # 是否转义所有特殊字符 + h.body_width = 0 # 设置行宽为 0 表示不限制行宽 + + # 将 HTML 转换为 Markdown + markdown_content = h.handle(modified_html) + for table in tables: + table_markdown = self._convert_table_to_markdown(table) + "\n\n" + placeholder = f"<table_placeholder_{id(table)}>" + markdown_content = markdown_content.replace(placeholder, table_markdown) + + markdown_content = self._replace_image_paths(html_name, markdown_content) + + return markdown_content + + except Exception as e: + logger.exception(e) + return None + + def load_data( + self, + file_path: Union[Path, str], + metadata: bool = True, + extra_info: Optional[Dict] = None, + ) -> List[Document]: + """Loads list of documents from Html file and also accepts extra information in dict format.""" + return self.load(file_path, metadata=metadata, extra_info=extra_info) + + def load( + self, + file_path: Union[Path, str], + metadata: bool = True, + extra_info: Optional[Dict] = None, + ) -> List[Document]: + """Loads list of documents from Html file and also accepts extra information in dict format. + + Args: + file_path (Union[Path, str]): file path of Html file (accepts string or Path). + metadata (bool, optional): if metadata to be included or not. Defaults to True. + extra_info (Optional[Dict], optional): extra information related to each document in dict format. Defaults to None. + + Raises: + TypeError: if extra_info is not a dictionary. + TypeError: if file_path is not a string or Path. + + Returns: + List[Document]: list of documents. 
+ """ + + md_content = self.convert_html_to_markdown(file_path) + logger.info(f"[PaiHtmlReader] successfully processed html file {file_path}.") + docs = [] + if metadata and extra_info: + extra_info = extra_info + else: + extra_info = dict() + logger.info(f"processed html file {file_path} without metadata") + doc = Document(text=md_content, extra_info=extra_info) + docs.append(doc) + print(f"[PaiHtmlReader] successfully loaded {len(docs)} nodes.") + return docs diff --git a/src/pai_rag/utils/markdown_utils.py b/src/pai_rag/utils/markdown_utils.py index ea4aae57..3a5db69a 100644 --- a/src/pai_rag/utils/markdown_utils.py +++ b/src/pai_rag/utils/markdown_utils.py @@ -9,15 +9,51 @@ class PaiTable(BaseModel): - headers: Optional[List[List[str]]] = ( - Field(description="The table headers.", default=None), + data: List[List[str]] = (Field(description="The table data.", default=[]),) + row_headers_index: Optional[List[int]] = ( + Field(description="The table row headers index.", default=None), ) - rows: Optional[List[List[str]]] = Field(description="The table rows.", default=None) + column_headers_index: Optional[List[int]] = ( + Field(description="The table column headers index.", default=None), + ) + + def get_row_numbers(self): + return len(self.data) + + def get_col_numbers(self): + return len(self.data[0]) + + def get_row_headers(self): + if not self.row_headers_index: + return [] + return [self.data[row] for row in self.row_headers_index] + + def get_rows(self): + if self.row_headers_index: + data_row_start_index = max(self.row_headers_index) + 1 + else: + data_row_start_index = 0 + return self.data[data_row_start_index:] + + def get_column_headers(self): + if not self.column_headers_index: + return [] + return [[row[i] for i in self.column_headers_index] for row in self.data] + + def get_columns(self): + if self.column_headers_index: + data_col_start_index = max(self.column_headers_index) + 1 + else: + data_col_start_index = 0 + return [ + [row[i] for i in 
range(data_col_start_index, self.get_col_numbers())] + for row in self.data + ] def transform_local_to_oss(oss_cache: Any, image: PngImageFile, doc_name: str) -> str: try: - if image.mode == "RGBA": + if image.mode != "RGB": image = image.convert("RGB") if image.width <= 50 or image.height <= 50: return None @@ -73,11 +109,17 @@ def _table_to_markdown(self, table, doc_name): def convert_table_to_markdown(table: PaiTable, total_cols: int) -> str: markdown = [] - if table.headers: - for header in table.headers: + if len(table.get_column_headers()) > 0: + headers = table.get_column_headers() + rows = table.get_columns() + else: + headers = table.get_row_headers() + rows = table.get_rows() + if headers: + for header in headers: markdown.append("| " + " | ".join(header) + " |") - markdown.append("| " + " | ".join(["---"] * total_cols) + " |") - if table.rows: - for row in table.rows: + markdown.append("| " + " | ".join(["---"] * total_cols) + " |") + if rows: + for row in rows: markdown.append("| " + " | ".join(row) + " |") return "\n".join(markdown)