Skip to content

Commit

Permalink
Refactor for multimodal (#232)
Browse files Browse the repository at this point in the history
* Refactor

* Fix mm embedding

* Fix node id bug

* Fix retriever

* Add faiss debug

* Fix reranker

* Fix tests

* Fix oss upload

* Update

* Fix milvus weights

* Fix opensearch multithreading

* Fix UI
  • Loading branch information
moria97 authored Sep 25, 2024
1 parent b511141 commit 19db02e
Show file tree
Hide file tree
Showing 82 changed files with 3,253 additions and 4,081 deletions.
12 changes: 0 additions & 12 deletions docs/config_guide_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,18 +252,6 @@ type = "RetrieverQueryEngine"

查询引擎(query engine)是一个通用接口,接收自然语言查询,并返回丰富的响应。

## rag.llm_chat_engine

type = "SimpleChatEngine"

基于 query engine 之上的一个高级接口,用于与数据进行对话(而不是单一的问答),可类比为状态化的查询引擎。

## rag.chat_engine

type = [CondenseQuestionChatEngine]

Condense question 是建立在查询引擎query engine之上的简易聊天模式。每次聊天交互中:首先从对话上下文和最后一条消息生成一个独立的问题,然后用这个简化的问题查询查询引擎以获取回复。

## rag.chat_store

type = [Local, Aliyun-Redis]
Expand Down
13 changes: 0 additions & 13 deletions docs/config_guide_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,19 +256,6 @@ type = "RetrieverQueryEngine"

Query engine is a generic interface that allows you to ask questions over your data. It takes in a natural language query, and returns a rich response.

## rag.llm_chat_engine

type = "SimpleChatEngine"

Chat engine is a high-level interface for having a conversation with your data (multiple back-and-forth interactions instead of a single question & answer). It is a stateful analog of a query engine.

## rag.chat_engine

type = [CondenseQuestionChatEngine]

Condense question is a simple chat mode built on top of a query engine over your data. For each chat interaction:
first generate a standalone question from the conversation context and the last message, then query the query engine with the condensed question for a response.

## rag.chat_store

type = [Local, Aliyun-Redis]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ type = "single"
[rag.agent.tool]
type = "api"

[rag.chat_engine]
type = "CondenseQuestionChatEngine"

[rag.chat_store]
type = "Local" # [Local, Aliyun-Redis]
host = "Aliyun-Redis host"
Expand Down Expand Up @@ -66,9 +63,6 @@ name = "qwen2-7b-instruct"
[rag.llm.multi_modal]
source = ""

[rag.llm_chat_engine]
type = "SimpleChatEngine"

[rag.node_enhancement]
tree_depth = 3
max_clusters = 52
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ type = "single"
[rag.agent.tool]
type = "python"

[rag.chat_engine]
type = "CondenseQuestionChatEngine"

[rag.chat_store]
type = "Local" # [Local, Aliyun-Redis]
host = "Aliyun-Redis host"
Expand Down Expand Up @@ -66,9 +63,6 @@ name = "qwen2-7b-instruct"
[rag.llm.multi_modal]
source = ""

[rag.llm_chat_engine]
type = "SimpleChatEngine"

[rag.node_enhancement]
tree_depth = 3
max_clusters = 52
Expand Down
98 changes: 49 additions & 49 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from fastapi.responses import StreamingResponse
import logging

from pai_rag.integrations.nodeparsers.pai.pai_node_parser import (
COMMON_FILE_PATH_FODER_NAME,
)

logger = logging.getLogger(__name__)

router = APIRouter()
Expand Down Expand Up @@ -128,61 +132,57 @@ async def generate_qa_dataset(overwrite: bool = False):

@router.post("/upload_data")
async def upload_data(
files: List[UploadFile],
files: List[UploadFile] = Body(None),
oss_path: str = Form(None),
faiss_path: str = Form(None),
enable_raptor: bool = Form(False),
background_tasks: BackgroundTasks = BackgroundTasks(),
):
task_id = uuid.uuid4().hex
if not files:
return {"message": "No upload file sent"}

tmpdir = tempfile.mkdtemp()
input_files = []
for file in files:
fn = file.filename
data = await file.read()
file_hash = hashlib.md5(data).hexdigest()
save_file = os.path.join(tmpdir, f"{file_hash}_{fn}")

with open(save_file, "wb") as f:
f.write(data)
f.close()
input_files.append(save_file)

background_tasks.add_task(
rag_service.add_knowledge,
task_id=task_id,
input_files=input_files,
filter_pattern=None,
oss_prefix=None,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
)

return {"task_id": task_id}


@router.post("/upload_data_from_oss")
async def upload_oss_data(
oss_prefix: str = None,
faiss_path: str = None,
enable_raptor: bool = False,
background_tasks: BackgroundTasks = BackgroundTasks(),
):
task_id = uuid.uuid4().hex
background_tasks.add_task(
rag_service.add_knowledge,
task_id=task_id,
input_files=None,
filter_pattern=None,
oss_prefix=oss_prefix,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
from_oss=True,
)
if oss_path:
background_tasks.add_task(
rag_service.add_knowledge,
task_id=task_id,
filter_pattern=None,
oss_path=oss_path,
from_oss=True,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
)
else:
if not files:
return {"message": "No upload file sent"}

tmpdir = tempfile.mkdtemp()
input_files = []
for file in files:
fn = file.filename
data = await file.read()
file_hash = hashlib.md5(data).hexdigest()
tmp_file_dir = os.path.join(
tmpdir, f"{COMMON_FILE_PATH_FODER_NAME}/{file_hash}"
)
os.makedirs(tmp_file_dir, exist_ok=True)
save_file = os.path.join(tmp_file_dir, fn)

with open(save_file, "wb") as f:
f.write(data)
f.close()
input_files.append(save_file)

background_tasks.add_task(
rag_service.add_knowledge,
task_id=task_id,
input_files=input_files,
filter_pattern=None,
oss_path=None,
faiss_path=faiss_path,
enable_qa_extraction=False,
enable_raptor=enable_raptor,
temp_file_dir=tmpdir,
)

return {"task_id": task_id}

Expand Down
14 changes: 7 additions & 7 deletions src/pai_rag/app/web/event_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,19 @@ def change_vectordb_conn(vectordb_type):
milvus_visible = False
opensearch_visible = False
postgresql_visible = False
if vectordb_type == "AnalyticDB":
if vectordb_type.lower() == "analyticdb":
adb_visible = True
elif vectordb_type == "Hologres":
elif vectordb_type.lower() == "hologres":
hologres_visible = True
elif vectordb_type == "ElasticSearch":
elif vectordb_type.lower() == "elasticsearch":
es_visible = True
elif vectordb_type == "Milvus":
elif vectordb_type.lower() == "milvus":
milvus_visible = True
elif vectordb_type == "FAISS":
elif vectordb_type.lower() == "faiss":
faiss_visible = True
elif vectordb_type == "OpenSearch":
elif vectordb_type.lower() == "opensearch":
opensearch_visible = True
elif vectordb_type == "PostgreSQL":
elif vectordb_type.lower() == "postgresql":
postgresql_visible = True

return [
Expand Down
23 changes: 14 additions & 9 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,18 +338,23 @@ def query_vector(self, text: str):

def add_knowledge(
self,
input_files: str,
enable_qa_extraction: bool,
enable_raptor: bool,
oss_path: str = None,
input_files: str = None,
enable_qa_extraction: bool = False,
enable_raptor: bool = False,
):
files = []
file_obj_list = []
for file_name in input_files:
file_obj = open(file_name, "rb")
mimetype = mimetypes.guess_type(file_name)[0]
files.append(("files", (os.path.basename(file_name), file_obj, mimetype)))
file_obj_list.append(file_obj)
para = {"enable_raptor": enable_raptor}
if input_files:
for file_name in input_files:
file_obj = open(file_name, "rb")
mimetype = mimetypes.guess_type(file_name)[0]
files.append(
("files", (os.path.basename(file_name), file_obj, mimetype))
)
file_obj_list.append(file_obj)

para = {"enable_raptor": enable_raptor, "oss_path": oss_path}
try:
r = requests.post(
self.load_data_url,
Expand Down
Loading

0 comments on commit 19db02e

Please sign in to comment.