Personal/ranxia/html reader #255

Merged 7 commits on Oct 28, 2024
Changes from 3 commits
6 changes: 2 additions & 4 deletions src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
@@ -40,10 +40,8 @@ class NodeParserConfig(BaseModel):
buffer_size: int = DEFAULT_BUFFER_SIZE


-DOC_TYPES_DO_NOT_NEED_CHUNKING = set(
-    [".csv", ".xlsx", ".xls", ".htm", ".html", ".jsonl"]
-)
-DOC_TYPES_CONVERT_TO_MD = set([".md", ".pdf", ".docx"])
+DOC_TYPES_DO_NOT_NEED_CHUNKING = set([".csv", ".xlsx", ".xls", ".jsonl"])
+DOC_TYPES_CONVERT_TO_MD = set([".md", ".pdf", ".docx", ".htm", ".html"])
IMAGE_FILE_TYPES = set([".jpg", ".jpeg", ".png"])

IMAGE_URL_REGEX = re.compile(
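In effect, html/htm files now take the markdown-conversion path instead of being passed through unchunked. A quick self-check of the new sets (a sketch; the import assumes the module path shown in this diff):

from pai_rag.integrations.nodeparsers.pai.pai_node_parser import (
    DOC_TYPES_CONVERT_TO_MD,
    DOC_TYPES_DO_NOT_NEED_CHUNKING,
)

# With this change, html/htm are converted to markdown (and chunked like the
# other markdown-converted types) rather than skipping chunking.
assert ".html" in DOC_TYPES_CONVERT_TO_MD and ".htm" in DOC_TYPES_CONVERT_TO_MD
assert ".html" not in DOC_TYPES_DO_NOT_NEED_CHUNKING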
12 changes: 9 additions & 3 deletions src/pai_rag/integrations/readers/pai/pai_data_reader.py
@@ -4,7 +4,7 @@
from pai_rag.integrations.readers.markdown_reader import MarkdownReader
from pai_rag.integrations.readers.pai_image_reader import PaiImageReader
from pai_rag.integrations.readers.pai_pdf_reader import PaiPDFReader
-from pai_rag.integrations.readers.html.html_reader import HtmlReader
+from pai_rag.integrations.readers.pai_html_reader import PaiHtmlReader
from pai_rag.integrations.readers.pai_csv_reader import PaiPandasCSVReader
from pai_rag.integrations.readers.pai_excel_reader import PaiPandasExcelReader
from pai_rag.integrations.readers.pai_jsonl_reader import PaiJsonLReader
@@ -32,8 +32,14 @@ def get_file_readers(reader_config: BaseDataReaderConfig = None, oss_store: Any
image_reader = PaiImageReader(oss_cache=oss_store)

file_readers = {
".html": HtmlReader(),
".htm": HtmlReader(),
".html": PaiHtmlReader(
enable_table_summary=reader_config.enable_table_summary,
oss_cache=oss_store, # Storing html images
),
".htm": PaiHtmlReader(
enable_table_summary=reader_config.enable_table_summary,
oss_cache=oss_store, # Storing html images
),
".docx": PaiDocxReader(
enable_table_summary=reader_config.enable_table_summary,
oss_cache=oss_store, # Storing docx images
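For reference, a minimal sketch of how this new registration is consumed. Assumptions (not verified against the rest of the repository): BaseDataReaderConfig can be constructed with only enable_table_summary, oss_store=None is acceptable when no image re-hosting is needed, and "page.html" is a placeholder path.

from pai_rag.integrations.readers.pai.pai_data_reader import (
    BaseDataReaderConfig,
    get_file_readers,
)

# Hypothetical local run without an OSS store; pages with remote images would
# need a real oss_store so they can be re-hosted.
readers = get_file_readers(BaseDataReaderConfig(enable_table_summary=False), None)
html_reader = readers[".html"]  # now a PaiHtmlReader rather than the plain HtmlReader
docs = html_reader.load_data("page.html")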
2 changes: 1 addition & 1 deletion src/pai_rag/integrations/readers/pai_docx_reader.py
@@ -223,7 +223,7 @@ def load_data(
metadata: bool = True,
extra_info: Optional[Dict] = None,
) -> List[Document]:
"""Loads list of documents from PDF file and also accepts extra information in dict format."""
"""Loads list of documents from Docx file and also accepts extra information in dict format."""
return self.load(file_path, metadata=metadata, extra_info=extra_info)

def load(
238 changes: 238 additions & 0 deletions src/pai_rag/integrations/readers/pai_html_reader.py
@@ -0,0 +1,238 @@
"""Html parser.

"""
import html2text
import logging
from bs4 import BeautifulSoup
import requests
from typing import Dict, List, Optional, Union, Any
from io import BytesIO
from pai_rag.utils.markdown_utils import (
transform_local_to_oss,
convert_table_to_markdown,
PaiTable,
)
from pathlib import Path
import re
import time
import os
from PIL import Image
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document

logger = logging.getLogger(__name__)

IMAGE_URL_PATTERN = (
r"!\[(?P<alt_text>.*?)\]\((https?://[^\s]+?[\s\w.-]*\.(jpg|jpeg|png|gif|bmp))\)"
)


class PaiHtmlReader(BaseReader):
"""Read html files including texts, tables, images.

Args:
enable_table_summary (bool): whether to use table_summary to process tables
"""

def __init__(
self,
enable_table_summary: bool = False,
oss_cache: Any = None,
) -> None:
self.enable_table_summary = enable_table_summary
self._oss_cache = oss_cache
logger.info(
f"PaiHtmlReader created with enable_table_summary : {self.enable_table_summary}"
)

def _extract_tables(self, html):
soup = BeautifulSoup(html, "html.parser")
tables = soup.find_all("table")
for table in tables:
            # Replace each table with a placeholder comment
placeholder = f"<!-- TABLE_PLACEHOLDER_{id(table)} -->"
table.replace_with(placeholder)
return str(soup), tables

def _convert_table_to_pai_table(self, table):
        # Record the indexes of header rows and columns
row_headers_index = []
col_headers_index = []
row_header_flag = True
col_header_index_max = -1
table_matrix = []
current_row_index = 0
max_cols = 0
max_rows = 0
for row in table.find_all("tr"):
current_col_index = 0
if current_row_index == 0:
row_cells = []
else:
row_cells = [""] * max_cols
if current_row_index >= max_rows:
table_matrix.append(row_cells)
for cell in row.find_all(["th", "td"]):
if cell.name != "th":
row_header_flag = False
else:
col_header_index_max = max(col_header_index_max, current_col_index)
cell_content = self._parse_cell_content(cell)
col_span = int(cell.get("colspan", 1))
row_span = int(cell.get("rowspan", 1))
if current_row_index != 0:
while (
current_col_index < max_cols
and table_matrix[current_row_index][current_col_index] != ""
):
current_col_index += 1
if (current_col_index > max_cols and max_cols != 0) or (
current_row_index > max_rows and max_rows != 0
):
break
for i in range(col_span):
if current_row_index == 0:
table_matrix[current_row_index].append(cell_content)
elif current_col_index + i < max_cols:
table_matrix[current_row_index][
current_col_index + i
] = cell_content
for i in range(1, row_span):
if current_row_index + i > max_rows:
table_matrix.append(row_cells)
table_matrix[current_row_index + i][
current_col_index
] = cell_content
max_rows = max(current_row_index + row_span, max_rows)
current_col_index += col_span
if current_row_index == 0:
max_cols += col_span
if row_header_flag:
row_headers_index.append(current_row_index)
current_row_index += 1

for i in range(col_header_index_max + 1):
col_headers_index.append(i)

table = PaiTable(
data=table_matrix,
row_headers_index=row_headers_index,
column_headers_index=col_headers_index,
)

return table, max_cols

    def _parse_cell_content(self, cell):
        content = []
        for element in cell.contents:
            # Plain-text nodes have no find_all, so only search tag elements for images
            if hasattr(element, "find_all"):
                image_links = [img.get("src") for img in element.find_all("img")]
                for image_url in image_links:
                    content.append(f"![]({image_url})")
            content.append(element.get_text() if hasattr(element, "get_text") else str(element))
        return " ".join(content)

def _convert_table_to_markdown(self, table):
table, total_cols = self._convert_table_to_pai_table(table)
return convert_table_to_markdown(table, total_cols)

def _transform_local_to_oss(self, html_name: str, image_url: str):
response = requests.get(image_url)
        response.raise_for_status()  # Raise an error if the download failed

        # Convert the binary response into a PIL image
image = Image.open(BytesIO(response.content))
return transform_local_to_oss(self._oss_cache, image, html_name)

def _replace_image_paths(self, html_name: str, content: str):
image_pattern = IMAGE_URL_PATTERN
matches = re.findall(image_pattern, content)
for alt_text, image_url, image_type in matches:
time_tag = int(time.time())
oss_url = self._transform_local_to_oss(html_name, image_url)
updated_alt_text = f"pai_rag_image_{time_tag}_{alt_text}"
content = content.replace(
f"![{alt_text}]({image_url})", f"![{updated_alt_text}]({oss_url})"
)

return content

def convert_html_to_markdown(self, html_path):
html_name = os.path.basename(html_path).split(".")[0]
html_name = html_name.replace(" ", "_")
try:
with open(html_path, "r", encoding="utf-8") as f:
html_content = f.read()

modified_html, tables = self._extract_tables(html_content)
h = html2text.HTML2Text()

            # Configure the html2text converter
            h.ignore_links = True  # Drop hyperlinks
            h.ignore_images = False  # Keep image references
            h.escape_all = True  # Escape all special characters
            h.body_width = 0  # 0 disables line wrapping

            # Convert the HTML (tables already replaced by placeholders) to Markdown
markdown_content = h.handle(modified_html)
for table in tables:
table_markdown = self._convert_table_to_markdown(table) + "\n\n"
placeholder = f"<!-- TABLE_PLACEHOLDER_{id(table)} -->"
markdown_content = markdown_content.replace(placeholder, table_markdown)

markdown_content = self._replace_image_paths(html_name, markdown_content)

return markdown_content

except Exception as e:
            logger.error(e)
return None

def load_data(
self,
file_path: Union[Path, str],
metadata: bool = True,
extra_info: Optional[Dict] = None,
) -> List[Document]:
"""Loads list of documents from Html file and also accepts extra information in dict format."""
return self.load(file_path, metadata=metadata, extra_info=extra_info)

def load(
self,
file_path: Union[Path, str],
metadata: bool = True,
extra_info: Optional[Dict] = None,
) -> List[Document]:
"""Loads list of documents from Html file and also accepts extra information in dict format.

Args:
file_path (Union[Path, str]): file path of Html file (accepts string or Path).
metadata (bool, optional): if metadata to be included or not. Defaults to True.
extra_info (Optional[Dict], optional): extra information related to each document in dict format. Defaults to None.

Raises:
TypeError: if extra_info is not a dictionary.
TypeError: if file_path is not a string or Path.

Returns:
List[Document]: list of documents.
"""

md_content = self.convert_html_to_markdown(file_path)
logger.info(f"[PaiHtmlReader] successfully processed html file {file_path}.")
docs = []
if metadata:
if not extra_info:
extra_info = {}
doc = Document(text=md_content, extra_info=extra_info)

docs.append(doc)
else:
doc = Document(
text=md_content,
extra_info=dict(),
)
docs.append(doc)
logger.info(f"processed html file {file_path} without metadata")
print(f"[PaiHtmlReader] successfully loaded {len(docs)} nodes.")
return docs
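A minimal usage sketch for the reader added above ("sample.html" is a hypothetical local file; with oss_cache=None the image re-hosting path is simply not exercised):

from pai_rag.integrations.readers.pai_html_reader import PaiHtmlReader

reader = PaiHtmlReader(enable_table_summary=False, oss_cache=None)
# Returns a single llama-index Document whose text is the markdown rendering,
# with tables rebuilt from the extracted <table> elements.
documents = reader.load_data("sample.html", metadata=True, extra_info={"file_name": "sample.html"})
print(documents[0].text[:200])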
60 changes: 51 additions & 9 deletions src/pai_rag/utils/markdown_utils.py
@@ -9,15 +9,51 @@


class PaiTable(BaseModel):
-    headers: Optional[List[List[str]]] = Field(description="The table headers.", default=None)
-    rows: Optional[List[List[str]]] = Field(description="The table rows.", default=None)
+    data: List[List[str]] = Field(description="The table data.", default=[])
+    row_headers_index: Optional[List[int]] = Field(description="The table row headers index.", default=None)
+    column_headers_index: Optional[List[int]] = Field(description="The table column headers index.", default=None)

def get_row_numbers(self):
return len(self.data)

def get_col_numbers(self):
return len(self.data[0])

def get_row_headers(self):
        if not self.row_headers_index:
return []
return [self.data[row] for row in self.row_headers_index]

def get_rows(self):
if self.row_headers_index:
data_row_start_index = max(self.row_headers_index) + 1
else:
data_row_start_index = 0
return self.data[data_row_start_index:]

def get_column_headers(self):
        if not self.column_headers_index:
return []
return [[row[i] for i in self.column_headers_index] for row in self.data]

def get_columns(self):
if self.column_headers_index:
            data_col_start_index = max(self.column_headers_index) + 1
else:
data_col_start_index = 0
return [
[row[i] for i in range(data_col_start_index, self.get_col_numbers())]
for row in self.data
]


def transform_local_to_oss(oss_cache: Any, image: PngImageFile, doc_name: str) -> str:
try:
if image.mode == "RGBA":
if image.mode != "RGB":
image = image.convert("RGB")
if image.width <= 50 or image.height <= 50:
return None
@@ -73,11 +109,17 @@ def _table_to_markdown(self, table, doc_name):

def convert_table_to_markdown(table: PaiTable, total_cols: int) -> str:
markdown = []
-    if table.headers:
-        for header in table.headers:
+    if len(table.get_column_headers()) > 0:
+        headers = table.get_column_headers()
+        rows = table.get_columns()
+    else:
+        headers = table.get_row_headers()
+        rows = table.get_rows()
+    if headers:
+        for header in headers:
            markdown.append("| " + " | ".join(header) + " |")
-    markdown.append("| " + " | ".join(["---"] * total_cols) + " |")
-    if table.rows:
-        for row in table.rows:
+        markdown.append("| " + " | ".join(["---"] * total_cols) + " |")
+    if rows:
+        for row in rows:
            markdown.append("| " + " | ".join(row) + " |")
return "\n".join(markdown)