Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nodes enhancement by raptor #111

Merged
merged 23 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 107 additions & 14 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ markdown = "^3.6"
chardet = "^5.2.0"
locust = "^2.29.0"
gunicorn = "^22.0.0"
umap-learn = "^0.5.6"
protobuf = "3.20.0"
modelscope = "^1.16.0"
llama-index-multi-modal-llms-dashscope = "^0.1.2"
Expand Down
1 change: 1 addition & 0 deletions pyproject_gpu.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ markdown = "^3.6"
chardet = "^5.2.0"
locust = "^2.29.0"
gunicorn = "^22.0.0"
umap-learn = "^0.5.6"
protobuf = "3.20.0"
modelscope = "^1.16.0"
llama-index-multi-modal-llms-dashscope = "^0.1.2"
Expand Down
1 change: 1 addition & 0 deletions src/pai_rag/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,4 @@ class RagResponse(BaseModel):
class DataInput(BaseModel):
file_path: str
enable_qa_extraction: bool = False
enable_raptor: bool = False
2 changes: 2 additions & 0 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ async def generate_qa_dataset(overwrite: bool = False):
async def upload_data(
files: List[UploadFile],
faiss_path: str = Form(None),
enable_raptor: bool = Form(False),
background_tasks: BackgroundTasks = BackgroundTasks(),
):
task_id = uuid.uuid4().hex
Expand All @@ -142,6 +143,7 @@ async def upload_data(
filter_pattern=None,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
)

return {"task_id": task_id}
11 changes: 8 additions & 3 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,23 @@ def query_vector(self, text: str):
response["answer"] = formatted_text
return response

def add_knowledge(self, input_files: str, enable_qa_extraction: bool):
def add_knowledge(
self, input_files: str, enable_qa_extraction: bool, enable_raptor: bool
):
files = []
file_obj_list = []
for file_name in input_files:
file_obj = open(file_name, "rb")
mimetype = mimetypes.guess_type(file_name)[0]
files.append(("files", (os.path.basename(file_name), file_obj, mimetype)))
file_obj_list.append(file_obj)

para = {"enable_raptor": enable_raptor}
try:
r = requests.post(
self.load_data_url, files=files, timeout=DEFAULT_CLIENT_TIME_OUT
self.load_data_url,
files=files,
data=para,
timeout=DEFAULT_CLIENT_TIME_OUT,
)
response = dotdict(json.loads(r.text))
if r.status_code != HTTPStatus.OK:
Expand Down
22 changes: 14 additions & 8 deletions src/pai_rag/app/web/tabs/upload_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import asyncio


def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
def upload_knowledge(
upload_files, chunk_size, chunk_overlap, enable_qa_extraction, enable_raptor
):
if not upload_files:
return

Expand All @@ -28,13 +30,9 @@ def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extracti
),
]

try:
response = rag_client.add_knowledge(
[file.name for file in upload_files], enable_qa_extraction
)
except RagApiError as api_error:
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")

response = rag_client.add_knowledge(
[file.name for file in upload_files], enable_qa_extraction, enable_raptor
)
my_upload_files = []
for file in upload_files:
my_upload_files.append(
Expand Down Expand Up @@ -94,6 +92,11 @@ def create_upload_tab() -> Dict[str, Any]:
info="Process with QA Extraction Model",
elem_id="enable_qa_extraction",
)
enable_raptor = gr.Checkbox(
label="Yes",
info="Process with Raptor Node Enhancement",
elem_id="enable_raptor",
)
with gr.Column(scale=8):
with gr.Tab("Files"):
upload_file = gr.File(
Expand Down Expand Up @@ -121,6 +124,7 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_size,
chunk_overlap,
enable_qa_extraction,
enable_raptor,
],
outputs=[upload_file_state_df, upload_file_state],
api_name="upload_knowledge",
Expand All @@ -132,6 +136,7 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_size,
chunk_overlap,
enable_qa_extraction,
enable_raptor,
],
outputs=[upload_dir_state_df, upload_dir_state],
api_name="upload_knowledge_dir",
Expand All @@ -140,4 +145,5 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_size.elem_id: chunk_size,
chunk_overlap.elem_id: chunk_overlap,
enable_qa_extraction.elem_id: enable_qa_extraction,
enable_raptor.elem_id: enable_raptor,
}
6 changes: 6 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ViewModel(BaseModel):
# reader
reader_type: str = "SimpleDirectoryReader"
enable_qa_extraction: bool = False
enable_raptor: bool = False

config_file: str = None

Expand Down Expand Up @@ -203,6 +204,9 @@ def from_app_config(config):
view_model.enable_qa_extraction = config["data_reader"].get(
"enable_qa_extraction", view_model.enable_qa_extraction
)
view_model.enable_raptor = config["data_reader"].get(
"enable_raptor", view_model.enable_raptor
)

view_model.similarity_top_k = config["retriever"].get("similarity_top_k", 5)
if config["retriever"]["retrieval_mode"] == "hybrid":
Expand Down Expand Up @@ -269,6 +273,7 @@ def to_app_config(self):
config["node_parser"]["chunk_overlap"] = int(self.chunk_overlap)

config["data_reader"]["enable_qa_extraction"] = self.enable_qa_extraction
config["data_reader"]["enable_raptor"] = self.enable_raptor
config["data_reader"]["type"] = self.reader_type

if self.vectordb_type == "Hologres":
Expand Down Expand Up @@ -440,6 +445,7 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
settings["chunk_size"] = {"value": self.chunk_size}
settings["chunk_overlap"] = {"value": self.chunk_overlap}
settings["enable_qa_extraction"] = {"value": self.enable_qa_extraction}
settings["enable_raptor"] = {"value": self.enable_raptor}
settings["similarity_top_k"] = {"value": self.similarity_top_k}
settings["rerank_model"] = {"value": self.rerank_model}
settings["retrieval_mode"] = {"value": self.retrieval_mode}
Expand Down
7 changes: 6 additions & 1 deletion src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ embed_batch_size = 10
[rag.evaluation]
retrieval = ["mrr", "hit_rate"]
response_label = true
response = ["Faithfulness", "Answer Relevancy", "Guideline Adherence", "Correctness", "Semantic Similarity"]
response = ["Faithfulness", "Answer Relevancy", "Correctness", "Semantic Similarity"]
moria97 marked this conversation as resolved.
Show resolved Hide resolved

[rag.index]
persist_path = "localdata/storage"
Expand All @@ -53,6 +53,11 @@ type = "SimpleChatEngine"
[rag.multi_modal_llm]
source = ""

[rag.node_enhancement]
moria97 marked this conversation as resolved.
Show resolved Hide resolved
tree_depth = 3
max_clusters = 52
proba_threshold = 0.10

[rag.node_parser]
type = "Sentence"
chunk_size = 500
Expand Down
5 changes: 4 additions & 1 deletion src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def aload_knowledge(
filter_pattern=None,
faiss_path=None,
enable_qa_extraction=False,
enable_raptor=False,
):
sessioned_config = self.config
if faiss_path:
Expand All @@ -50,7 +51,9 @@ async def aload_knowledge(
data_loader = module_registry.get_module_with_config(
"DataLoaderModule", sessioned_config
)
await data_loader.aload(input_files, filter_pattern, enable_qa_extraction)
await data_loader.aload(
input_files, filter_pattern, enable_qa_extraction, enable_raptor
)

async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
Expand Down
7 changes: 6 additions & 1 deletion src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ async def add_knowledge_async(
filter_pattern: str = None,
faiss_path: str = None,
enable_qa_extraction: bool = False,
enable_raptor: bool = False,
):
self.check_updates()
with open(TASK_STATUS_FILE, "a") as f:
f.write(f"{task_id}\tprocessing\n")
try:
await self.rag.aload_knowledge(
input_files, filter_pattern, faiss_path, enable_qa_extraction
input_files,
filter_pattern,
faiss_path,
enable_qa_extraction,
enable_raptor,
)
with open(TASK_STATUS_FILE, "a") as f:
f.write(f"{task_id}\tcompleted\n")
Expand Down
Loading
Loading