From f3af3dd449e4be9b42f214e98aabdadd589353a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=87=83=E5=A4=8F?= Date: Tue, 27 Aug 2024 19:41:36 +0800 Subject: [PATCH] 1. change image size 2. limit image number 3. fix retriever answer ui format --- src/pai_rag/app/web/rag_client.py | 5 ++++- .../query_engine/multi_modal_query_engine.py | 15 ++++++++++++++- .../integrations/readers/pai_pdf_reader.py | 15 ++++++++++++++- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index ab4216f0..0a1ed8a9 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -5,6 +5,8 @@ import os import re import mimetypes +import markdown +import html from http import HTTPStatus from pai_rag.app.web.view_model import ViewModel from pai_rag.app.web.ui_constants import EMPTY_KNOWLEDGEBASE_MESSAGE @@ -187,6 +189,7 @@ def query_vector(self, text: str): response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=text) else: for i, doc in enumerate(response["docs"]): + html_content = markdown.markdown(doc["text"]) file_url = doc.get("metadata", {}).get("file_url", None) media_url = doc.get("metadata", {}).get("image_url", None) if media_url and isinstance(media_url, list): @@ -198,7 +201,7 @@ def query_vector(self, text: str): ) elif media_url: media_url = f"""""" - safe_html_content = doc["text"] + safe_html_content = html.escape(html_content).replace("\n", "
") if file_url: safe_html_content = ( f"""{safe_html_content}""" diff --git a/src/pai_rag/integrations/query_engine/multi_modal_query_engine.py b/src/pai_rag/integrations/query_engine/multi_modal_query_engine.py index 23d42b23..30d9b6ee 100644 --- a/src/pai_rag/integrations/query_engine/multi_modal_query_engine.py +++ b/src/pai_rag/integrations/query_engine/multi_modal_query_engine.py @@ -26,6 +26,8 @@ Response, ) +IMAGE_MAX_PIECES = 5 + if TYPE_CHECKING: from llama_index.core.indices.multi_modal import MultiModalVectorIndexRetriever @@ -42,7 +44,12 @@ def _get_image_and_text_nodes( nodes: List[NodeWithScore], ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]: image_nodes = [] + text_image_nodes = [] text_nodes = [] + image_urls = set() + for res_node in nodes: + if isinstance(res_node.node, ImageNode): + image_urls.add(res_node.node.image_url) for res_node in nodes: if isinstance(res_node.node, ImageNode): image_nodes.append(res_node) @@ -50,11 +57,13 @@ def _get_image_and_text_nodes( text_nodes.append(res_node) if res_node.node.metadata.get("image_url", None): for image_url in res_node.node.metadata["image_url"]: + if image_url in image_urls: + continue extra_info = { "image_url": image_url, "file_name": res_node.node.metadata.get("file_name", ""), } - image_nodes.append( + text_image_nodes.append( NodeWithScore( node=ImageNode( image_url=image_url, @@ -63,6 +72,10 @@ def _get_image_and_text_nodes( score=res_node.score, ) ) + image_nodes.sort(key=lambda x: x.score, reverse=True) + text_image_nodes.sort(key=lambda x: x.score, reverse=True) + image_nodes.extend(text_image_nodes) + image_nodes = image_nodes[:IMAGE_MAX_PIECES] return image_nodes, text_nodes diff --git a/src/pai_rag/integrations/readers/pai_pdf_reader.py b/src/pai_rag/integrations/readers/pai_pdf_reader.py index 8c57dfa6..1671136d 100644 --- a/src/pai_rag/integrations/readers/pai_pdf_reader.py +++ b/src/pai_rag/integrations/readers/pai_pdf_reader.py @@ -14,12 +14,12 @@ import magic_pdf.model as model_config import tempfile import re +import math import requests from PIL import Image from rapidocr_onnxruntime import RapidOCR from rapid_table import RapidTable - import logging import os from io import BytesIO @@ -29,6 +29,7 @@ logger = logging.getLogger(__name__) +IMAGE_MAX_PIXELS = 512 * 512 TABLE_SUMMARY_MAX_ROW_NUM = 5 TABLE_SUMMARY_MAX_COL_NUM = 10 TABLE_SUMMARY_MAX_CELL_TOKEN = 20 @@ -82,6 +83,18 @@ def replace_func(match): if image.width <= 50 or image.height <= 50: return None + current_pixels = image.width * image.height + + # 检查像素总数是否超过限制 + if current_pixels > IMAGE_MAX_PIXELS: + # 计算缩放比例以适应最大像素数 + scale = math.sqrt(IMAGE_MAX_PIXELS / current_pixels) + new_width = int(image.width * scale) + new_height = int(image.height * scale) + + # 调整图片大小 + image = image.resize((new_width, new_height), Image.LANCZOS) + image_stream = BytesIO() image.save(image_stream, format="jpeg")