Skip to content

Commit

Permalink
Md parser (#238)
Browse files Browse the repository at this point in the history
* pdf_reader & md_parser

* pdf_reader & md_parser

* pdf_reader & md_parser

* pdf_reader & md_parser

* pdf_reader & md_parser

* pdf_reader & md_parser

* pdf_reader & md_parser
  • Loading branch information
Ceceliachenen authored Oct 8, 2024
1 parent 1de95e9 commit 733613b
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 343 deletions.
13 changes: 6 additions & 7 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ class ViewModel(BaseModel):
oss_endpoint: str = None
oss_bucket: str = None

# chunking
# node_parser
parser_type: str = "Sentence"
chunk_size: int = 500
chunk_overlap: int = 20
enable_multimodal: bool = False

# reader
reader_type: str = "SimpleDirectoryReader"
enable_qa_extraction: bool = False
enable_raptor: bool = False
enable_multimodal: bool = False
enable_table_summary: bool = False

config_file: str = None
Expand Down Expand Up @@ -326,6 +326,9 @@ def from_app_config(config):
view_model.parser_type = config["node_parser"]["type"]
view_model.chunk_size = config["node_parser"]["chunk_size"]
view_model.chunk_overlap = config["node_parser"]["chunk_overlap"]
view_model.enable_multimodal = config["node_parser"].get(
"enable_multimodal", view_model.enable_multimodal
)

view_model.reader_type = config["data_reader"].get(
"type", view_model.reader_type
Expand All @@ -336,9 +339,6 @@ def from_app_config(config):
view_model.enable_raptor = config["data_reader"].get(
"enable_raptor", view_model.enable_raptor
)
view_model.enable_multimodal = config["data_reader"].get(
"enable_multimodal", view_model.enable_multimodal
)
view_model.enable_table_summary = config["data_reader"].get(
"enable_table_summary", view_model.enable_table_summary
)
Expand Down Expand Up @@ -481,10 +481,10 @@ def to_app_config(self):
config["node_parser"]["type"] = self.parser_type
config["node_parser"]["chunk_size"] = int(self.chunk_size)
config["node_parser"]["chunk_overlap"] = int(self.chunk_overlap)
config["node_parser"]["enable_multimodal"] = self.enable_multimodal

config["data_reader"]["enable_qa_extraction"] = self.enable_qa_extraction
config["data_reader"]["enable_raptor"] = self.enable_raptor
config["data_reader"]["enable_multimodal"] = self.enable_multimodal
config["data_reader"]["enable_table_summary"] = self.enable_table_summary
config["data_reader"]["type"] = self.reader_type

Expand Down Expand Up @@ -611,7 +611,6 @@ def to_app_config(self):
config["search"]["search_lang"] = self.search_lang
config["search"]["search_count"] = self.search_count

print(config)
return _transform_to_dict(config)

def get_local_generated_qa_file(self):
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ type = "local"

[rag.data_reader]
type = "SimpleDirectoryReader"
enable_multimodal = true

# embedding configurations; supported API sources: OpenAI, DashScope; local models: HuggingFace
# If using an API, set OPENAI_API_KEY or DASHSCOPE_API_KEY in the environment; if using HuggingFace, set model_name
Expand Down Expand Up @@ -82,6 +81,7 @@ proba_threshold = 0.10
type = "Sentence"
chunk_size = 500
chunk_overlap = 10
enable_multimodal = true

[rag.oss_store]
bucket = ""
Expand Down
112 changes: 0 additions & 112 deletions src/pai_rag/config/settings_multi_modal.toml

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
from typing import Any, Dict, List, Optional
import json

from llama_index.core.base.base_multi_modal_retriever import (
MultiModalRetriever,
Expand Down Expand Up @@ -185,7 +186,7 @@ def _retrieve(
# Recall images referenced from the retrieved text nodes
if self._search_image and len(image_nodes) < self._image_similarity_top_k:
for node in text_nodes:
image_urls = node.node.metadata.get("image_url")
image_urls = json.loads(node.node.metadata.get("image_url_list_str"))
if not image_urls:
continue
for image_url in image_urls:
Expand Down Expand Up @@ -454,14 +455,14 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
task_results = await asyncio.gather(*tasks)

text_nodes, image_nodes = task_results[0], task_results[1]
logger.info(f"Retrieved text nodes: {text_nodes}")
logger.info(f"Retrieved image nodes: {image_nodes}")
logger.info(f"Retrieved text nodes: {len(text_nodes)}")
logger.info(f"Retrieved image nodes: {len(image_nodes)}")

seen_images = set([node.node.image_url for node in image_nodes])
# Recall images referenced from the retrieved text nodes
if self._search_image and len(image_nodes) < self._image_similarity_top_k:
for node in text_nodes:
image_urls = node.node.metadata.get("image_url")
image_urls = json.loads(node.node.metadata.get("image_url_list_str"))
if not image_urls:
continue
for image_url in image_urls:
Expand All @@ -482,6 +483,7 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
if not image_nodes:
image_nodes = []
results = text_nodes + image_nodes

return results

async def _atext_retrieve(
Expand Down
Loading

0 comments on commit 733613b

Please sign in to comment.