def _convert_table_to_markdown(self, table, doc_name):
    """Render a document table as a Markdown table string.

    Args:
        table: Table object exposing ``.rows``, each row exposing ``.cells``
            (presumably a python-docx table, judging by this module's
            ``docx`` import -- confirm against the caller).
        doc_name: Passed through unchanged to ``self._parse_row``.

    Returns:
        The string produced by ``convert_table_to_markdown``, or ``""`` when
        the table has no rows.
    """
    rows = list(table.rows)
    if not rows:
        # max() over an empty sequence raises ValueError; a table without
        # rows has nothing to render, so return an empty string instead.
        return ""

    # The widest row decides the column count handed to the row parser and
    # to the Markdown renderer.
    total_cols = max(len(row.cells) for row in rows)
    table_matrix = [self._parse_row(row, doc_name, total_cols) for row in rows]

    # Horizontal layout: row 0 is the header row.
    # Otherwise: column 0 is the header column.
    # NOTE(review): is_horizontal_table only counts int/float cells. If
    # _parse_row returns text cells only, it always answers True and the
    # else branch never runs -- confirm this is intended.
    if is_horizontal_table(table_matrix):
        pai_table = PaiTable(data=table_matrix, row_headers_index=[0])
    else:
        pai_table = PaiTable(data=table_matrix, column_headers_index=[0])
    return convert_table_to_markdown(pai_table, total_cols)
List[List]: for row in table ] - @staticmethod - def is_horizontal_table(table: List[List]) -> bool: - # if the table is empty or the first (header) of table is empty, it's not a horizontal table - if not table or not table[0]: - return False - - vertical_value_any_count = 0 - horizontal_value_any_count = 0 - vertical_value_all_count = 0 - horizontal_value_all_count = 0 - - """If it is a horizontal table, the probability that each row contains at least one number is higher than the probability that each column contains at least one number. - If it is a horizontal table with headers, the number of rows that are entirely composed of numbers will be greater than the number of columns that are entirely composed of numbers. - """ - - for row in table: - if any(isinstance(item, (int, float)) for item in row): - horizontal_value_any_count += 1 - if all(isinstance(item, (int, float)) for item in row): - horizontal_value_all_count += 1 - - for col in zip(*table): - if any(isinstance(item, (int, float)) for item in col): - vertical_value_any_count += 1 - if all(isinstance(item, (int, float)) for item in col): - vertical_value_all_count += 1 - - return ( - horizontal_value_any_count >= vertical_value_any_count - or horizontal_value_all_count > 0 >= vertical_value_all_count - ) - @staticmethod def tables_summarize(table: List[List]) -> str: table = PaiPDFReader.limit_table_content(table) - if not PaiPDFReader.is_horizontal_table(table): + if not is_horizontal_table(table): table = list(zip(*table)) table = table[:TABLE_SUMMARY_MAX_ROW_NUM] table = [row[:TABLE_SUMMARY_MAX_COL_NUM] for row in table] diff --git a/src/pai_rag/utils/markdown_utils.py b/src/pai_rag/utils/markdown_utils.py index 3a5db69a..4746a367 100644 --- a/src/pai_rag/utils/markdown_utils.py +++ b/src/pai_rag/utils/markdown_utils.py @@ -9,12 +9,12 @@ class PaiTable(BaseModel): - data: List[List[str]] = (Field(description="The table data.", default=[]),) - row_headers_index: Optional[List[int]] = ( - 
def _count_numeric_lines(lines):
    """Count lines with at least one number and lines made only of numbers.

    ``lines`` is any iterable of cell sequences (table rows or columns).
    A cell is a number when it is an ``int`` or ``float`` instance. ``bool``
    cells qualify too because ``bool`` subclasses ``int``; numeric text such
    as ``"30"`` does not.

    Returns:
        An ``(any_count, all_count)`` tuple.
    """
    any_count = 0
    all_count = 0
    for line in lines:
        flags = [isinstance(item, (int, float)) for item in line]
        any_count += any(flags)
        all_count += all(flags)
    return any_count, all_count


def is_horizontal_table(table: List[List]) -> bool:
    """Guess whether ``table`` is laid out horizontally.

    Heuristic: in a horizontal table, rows are more likely than columns to
    contain at least one number. In a horizontal table with headers, some
    rows consist entirely of numbers while no column does.

    Args:
        table: The table as a list of rows, each row a list of cells.

    Returns:
        ``False`` when the table or its first row is empty. Otherwise
        ``True`` when at least as many rows as columns contain a number, or
        when some row is all numbers and no column is.
    """
    # An empty table, or one whose first (header) row is empty, is not
    # treated as horizontal.
    if not table or not table[0]:
        return False

    rows_any, rows_all = _count_numeric_lines(table)
    # zip(*table) walks the columns; it stops at the shortest row, so cells
    # beyond that width are not inspected.
    cols_any, cols_all = _count_numeric_lines(zip(*table))

    return rows_any >= cols_any or (rows_all > 0 and cols_all == 0)