Skip to content

Commit

Permalink
Merge pull request #454 from netease-youdao/develop_for_v1.3.1
Browse files Browse the repository at this point in the history
Develop for v1.3.1
  • Loading branch information
xixihahaliu authored Jul 26, 2024
2 parents deaaf43 + e9247c8 commit 45a624b
Show file tree
Hide file tree
Showing 8 changed files with 698 additions and 15 deletions.
23 changes: 22 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# -*- coding: utf-8 -*-
from __future__ import print_function # 确保 print 函数在 Python 2 中的行为与 Python 3 一致

def get_run_config_params():
    """Return the runtime configuration as one comma-separated string.

    The string is presumably captured by a shell launcher script and split on
    commas (TODO confirm against the caller), so the field order matters:
    api_base, api_key, model_name, context_length, workers, milvus_port,
    qanything_port, use_cpu.
    """
    openai_api_base = "https://api.openai.com/v1"
    openai_api_key = "sk-xxxxxxx"
    openai_api_model_name = "gpt-3.5-turbo-1106"
    openai_api_context_length = "4096"
    workers = 4
    milvus_port = 19530
    qanything_port = 8777
    use_cpu = True
    # Use .format() (not f-strings) to stay compatible with Python 2.
    # NOTE: the format string must have 8 placeholders — with only 7, .format()
    # silently drops the trailing `use_cpu` argument from the output.
    return "{},{},{},{},{},{},{},{}".format(openai_api_base, openai_api_key, openai_api_model_name,
                                            openai_api_context_length, workers, milvus_port,
                                            qanything_port, use_cpu)

# 模型参数
llm_config = {
# 回答的最大token数,一般来说对于国内模型一个中文不到1个token,国外模型一个中文1.5-2个token
Expand Down Expand Up @@ -52,4 +68,9 @@
# 切割文件的相邻文本重合长度
"chunk_overlap": 0
}
#### 一般情况下,除非特殊需要,不要修改一下字段参数 ####
#### 一般情况下,除非特殊需要,不要修改以下字段参数 ####


if __name__ == "__main__":
    import sys
    # Write without a trailing newline so shell command substitution captures
    # the config string exactly as produced.
    # (The previous ''.join(...) wrapper was a no-op on a str and is removed.)
    sys.stdout.write(get_run_config_params())
23 changes: 19 additions & 4 deletions qanything_kernel/connector/database/mysql/mysql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ def create_tables_(self):
"""
self.execute_query_(query, (), commit=True)

# 旧的File不存在reason,补上默认值:' '
# 如果存在File表,但是没有reason字段,那么添加reason字段
query = "PRAGMA table_info(File)"
result = self.execute_query_(query, (), fetch=True)
if result:
reason_exist = False
for column_info in result:
if column_info[1] == 'reason':
reason_exist = True
break
if not reason_exist:
query = "ALTER TABLE File ADD COLUMN reason VARCHAR(512) DEFAULT ' '"
self.execute_query_(query, (), commit=True)


query = """
CREATE TABLE IF NOT EXISTS Document (
docstore_id VARCHAR(64) PRIMARY KEY,
Expand Down Expand Up @@ -424,9 +439,9 @@ def update_chunk_size(self, file_id, chunk_size):
query = "UPDATE File SET chunk_size = ? WHERE file_id = ?"
self.execute_query_(query, (chunk_size, file_id), commit=True)

def update_file_status(self, file_id, status):
query = "UPDATE File SET status = ? WHERE file_id = ?"
self.execute_query_(query, (status, file_id), commit=True)
def update_file_status(self, file_id, status, reason):
query = "UPDATE File SET status = ?, reason = ? WHERE file_id = ?"
self.execute_query_(query, (status, reason, file_id), commit=True)

def from_status_to_status(self, file_ids, from_status, to_status):
file_ids_str = ','.join("'{}'".format(str(x)) for x in file_ids)
Expand All @@ -436,7 +451,7 @@ def from_status_to_status(self, file_ids, from_status, to_status):

# [文件] 获取指定知识库下面所有文件的id和名称
def get_files(self, user_id, kb_id):
query = "SELECT file_id, file_name, status, file_size, content_length, timestamp FROM File WHERE kb_id = ? AND kb_id IN (SELECT kb_id FROM KnowledgeBase WHERE user_id = ?) AND deleted = 0"
query = "SELECT file_id, file_name, status, file_size, content_length, timestamp, reason FROM File WHERE kb_id = ? AND kb_id IN (SELECT kb_id FROM KnowledgeBase WHERE user_id = ?) AND deleted = 0"
return self.execute_query_(query, (kb_id, user_id), fetch=True)

def get_file_path(self, file_id):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self, use_cpu: bool = False):

def get_embedding(self, sentences, max_length):
inputs_onnx = self._tokenizer(sentences, padding=True, truncation=True, max_length=max_length, return_tensors=self.return_tensors)
debug_logger.info(f'embedding input shape: {inputs_onnx["input_ids"].shape}')
inputs_onnx = {k: v for k, v in inputs_onnx.items()}
start_time = time.time()
outputs_onnx = self._session.run(output_names=['output'], input_feed=inputs_onnx)
Expand Down
30 changes: 25 additions & 5 deletions qanything_kernel/core/local_doc_qa.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter

from qanything_kernel.configs.model_config import VECTOR_SEARCH_TOP_K, CHUNK_SIZE, VECTOR_SEARCH_SCORE_THRESHOLD, \
PROMPT_TEMPLATE, STREAMING, OCR_MODEL_PATH
from typing import List
Expand All @@ -11,13 +13,14 @@
from qanything_kernel.utils.custom_log import debug_logger, qa_logger
from qanything_kernel.core.tools.web_search_tool import duckduckgo_search
from qanything_kernel.dependent_server.ocr_server.ocr import OCRQAnything
from qanything_kernel.utils.general_utils import num_tokens
from qanything_kernel.utils.general_utils import num_tokens, get_time
from .local_file import LocalFile
import traceback
import base64
import numpy as np
import platform
import cv2
import re


class LocalDocQA:
Expand All @@ -34,6 +37,12 @@ def __init__(self):
self.mode: str = None
self.use_cpu: bool = True
self.model: str = None
self.web_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n", "。", "!", "!", "?", "?", ";", ";", "……", "…", "、", ",", ",", " ", ""],
chunk_size=800,
chunk_overlap=200,
length_function=num_tokens,
)

def get_ocr_result(self, input: dict):
img_file = input['img64']
Expand Down Expand Up @@ -94,17 +103,21 @@ async def insert_files_to_faiss(self, user_id, kb_id, local_files: List[LocalFil
except Exception as e:
error_info = f'split error: {traceback.format_exc()}'
debug_logger.error(error_info)
self.mysql_client.update_file_status(local_file.file_id, status='red')
self.mysql_client.update_file_status(local_file.file_id, status='red', reason='split或embedding失败,请检查文件类型,仅支持[md,txt,pdf,jpg,png,jpeg,docx,xlsx,pptx,eml,csv]')
failed_list.append(local_file)
continue
if len(local_file.docs) == 0:
self.mysql_client.update_file_status(local_file.file_id, status='red', reason='上传文件内容为空,请检查文件内容')
debug_logger.info(f'上传文件内容为空,请检查文件内容')
continue
end = time.time()
self.mysql_client.update_content_length(local_file.file_id, content_length)
debug_logger.info(f'split time: {end - start} {len(local_file.docs)}')
self.mysql_client.update_chunk_size(local_file.file_id, len(local_file.docs))
add_ids = await self.faiss_client.add_document(local_file.docs)
insert_time = time.time()
debug_logger.info(f'insert time: {insert_time - end}')
self.mysql_client.update_file_status(local_file.file_id, status='green')
self.mysql_client.update_file_status(local_file.file_id, status='green', reason=" ")
success_list.append(local_file)
debug_logger.info(
f"insert_to_faiss: success num: {len(success_list)}, failed num: {len(failed_list)}")
Expand All @@ -125,6 +138,7 @@ async def local_doc_search(self, query, kb_ids):
debug_logger.info(f"local doc search retrieval_documents: {retrieval_documents}")
return retrieval_documents

@get_time
def get_web_search(self, queries, top_k=None):
if not top_k:
top_k = self.top_k
Expand All @@ -133,11 +147,17 @@ def get_web_search(self, queries, top_k=None):
source_documents = []
for doc in web_documents:
doc.metadata['retrieval_query'] = query # 添加查询到文档的元数据中
file_name = re.sub(r'[\uFF01-\uFF5E\u3000-\u303F]', '', doc.metadata['title'])
doc.metadata['file_name'] = file_name + '.web'
doc.metadata['file_url'] = doc.metadata['source']
doc.metadata['embed_version'] = self.embeddings.embed_version
source_documents.append(doc)
if 'description' in doc.metadata:
desc_doc = Document(page_content=doc.metadata['description'], metadata=doc.metadata)
source_documents.append(desc_doc)
source_documents = self.web_splitter.split_documents(source_documents)
return web_content, source_documents



def web_page_search(self, query, top_k=None):
# 防止get_web_search调用失败,需要try catch
try:
Expand Down
8 changes: 4 additions & 4 deletions qanything_kernel/core/local_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def table_process(doc):
return new_docs

@staticmethod
def pdf_process(dos: List[Document]):
def pdf_process(docs: List[Document]):
new_docs = []
for doc in dos:
for doc in docs:
# metadata={'title_lst': ['#樊昊天个人简历', '##教育经历'], 'has_table': False}
title_lst = doc.metadata['title_lst']
# 删除所有仅有多个#的title
Expand Down Expand Up @@ -170,7 +170,7 @@ def split_file_to_docs(self, ocr_engine: Callable, sentence_size=SENTENCE_SIZE,
else:
try:
from qanything_kernel.utils.loader.self_pdf_loader import PdfLoader
loader = PdfLoader(filename=self.file_path, root_dir=os.path.dirname(self.file_path))
loader = PdfLoader(filename=self.file_path, save_dir=os.path.dirname(self.file_path))
markdown_dir = loader.load_to_markdown()
docs = convert_markdown_to_langchaindoc(markdown_dir)
docs = self.pdf_process(docs)
Expand Down Expand Up @@ -224,7 +224,7 @@ def split_file_to_docs(self, ocr_engine: Callable, sentence_size=SENTENCE_SIZE,
new_docs.append(doc)
else:
last_doc = new_docs[-1]
if len(last_doc.page_content) + len(doc.page_content) < min_length:
if num_tokens(last_doc.page_content) + num_tokens(doc.page_content) < min_length:
last_doc.page_content += '\n' + doc.page_content
else:
new_docs.append(doc)
Expand Down
2 changes: 1 addition & 1 deletion qanything_kernel/qanything_server/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def list_docs(req: request):
else:
status_count[status] += 1
data.append({"file_id": file_info[0], "file_name": file_info[1], "status": file_info[2], "bytes": file_info[3],
"content_length": file_info[4], "timestamp": file_info[5], "msg": msg_map[file_info[2]]})
"content_length": file_info[4], "timestamp": file_info[5], "msg": file_info[6]})
file_name = file_info[1]
file_id = file_info[0]
if file_name.endswith('.faq'):
Expand Down
Loading

0 comments on commit 45a624b

Please sign in to comment.