Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API to support upload local files #67

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,26 @@ curl -X 'POST' http://127.0.0.1:8000/service/batch_evaluate/response
}
```

3. 上传

支持通过API的方式上传本地文件,并支持指定不同的faiss_path,每次发送API请求会返回一个task_id,之后可以通过task_id来查看文件上传状态(processing、completed、failed)。

- **(1)上传(upload_local_data)**

```bash
curl -X 'POST' http://127.0.0.1:8000/service/upload_local_data -H 'Content-Type: multipart/form-data' -F 'file=@local_path/PAI.txt' -F 'faiss_path=localdata/storage'

# Return: {"task_id": "2c1e557733764fdb9fefa063538914da"}
```

- **(2)查看上传状态(get_upload_state)**

```bash
curl http://127.0.0.1:8000/service/get_upload_state\?task_id\=2c1e557733764fdb9fefa063538914da

# Return: {"task_id":"2c1e557733764fdb9fefa063538914da","status":"completed"}
```

### 独立脚本文件:不依赖于整体服务的启动,可独立运行

1. 向当前索引存储中插入新文件
Expand Down
33 changes: 32 additions & 1 deletion src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any
from fastapi import APIRouter, Body, BackgroundTasks
from fastapi import APIRouter, Body, BackgroundTasks, File, UploadFile, Form
import uuid
import os
import tempfile
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api.models import (
RagQuery,
Expand Down Expand Up @@ -86,3 +88,32 @@ async def batch_evaluate():
type="all"
)
return {"status": 200, "result": eval_results}


@router.post("/upload_local_data")
async def upload_local_data(
    file: UploadFile = File(),
    faiss_path: str = Form(),
    background_tasks: BackgroundTasks = BackgroundTasks(),
):
    """Save an uploaded file to a temp directory and schedule its ingestion.

    The upload is written into a freshly created temporary directory, and
    ``rag_service.add_knowledge_async`` is queued as a background task with
    that directory, the requested ``faiss_path`` and QA extraction disabled.

    Args:
        file: The multipart file part sent by the client.
        faiss_path: Index persist path forwarded to the ingestion task.
        background_tasks: Task queue used to run the ingestion after the
            response is sent.

    Returns:
        ``{"task_id": <hex id>}`` on success, or
        ``{"message": "No upload file sent"}`` when no file is present.

    Note:
        The temporary directory is not removed by this handler.
    """
    if not file:
        return {"message": "No upload file sent"}

    task_id = uuid.uuid4().hex

    # The filename is client-controlled: keep only its final path component
    # so values such as "../../etc/passwd" (or Windows-style "..\\x") cannot
    # make the write land outside the temporary directory.
    raw_name = file.filename or ""
    safe_name = os.path.basename(raw_name.replace("\\", "/"))

    tmpdir = tempfile.mkdtemp()
    save_file = os.path.join(tmpdir, f"{task_id}_{safe_name}")

    data = await file.read()
    # The context manager closes the file; no explicit close() is needed.
    with open(save_file, "wb") as f:
        f.write(data)

    background_tasks.add_task(
        rag_service.add_knowledge_async,
        task_id=task_id,
        file_dir=tmpdir,
        faiss_path=faiss_path,
        enable_qa_extraction=False,
    )

    return {"task_id": task_id}
12 changes: 10 additions & 2 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,17 @@ async def aload_knowledge(self, file_dir, enable_qa_extraction=False):
)
await data_loader.aload(file_dir, enable_qa_extraction)

def load_knowledge(self, file_dir, enable_qa_extraction=False):
def load_knowledge(self, file_dir, faiss_path=None, enable_qa_extraction=False):
sessioned_config = self.config
if faiss_path:
sessioned_config = self.config.copy()
sessioned_config.index.update({"persist_path": faiss_path})
self.logger.info(
f"Update rag_application config with faiss_persist_path: {faiss_path}"
)

data_loader = module_registry.get_module_with_config(
"DataLoaderModule", self.config
"DataLoaderModule", sessioned_config
)
data_loader.load(file_dir, enable_qa_extraction)

Expand Down
8 changes: 6 additions & 2 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ def reload(self, new_config: Any):
self.rag_configuration.persist()

def add_knowledge_async(
self, task_id: str, file_dir: str, enable_qa_extraction: bool = False
self,
task_id: str,
file_dir: str,
faiss_path: str = None,
enable_qa_extraction: bool = False,
):
self.tasks_status[task_id] = "processing"
try:
self.rag.load_knowledge(file_dir, enable_qa_extraction)
self.rag.load_knowledge(file_dir, faiss_path, enable_qa_extraction)
self.tasks_status[task_id] = "completed"
except Exception as ex:
logger.error(f"Upload failed: {ex}")
Expand Down
Loading